From 2048c4e37909a42847cd2f51c7e0cf92e3b63466 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 10 Sep 2025 23:53:24 -0700
Subject: [PATCH] [torchao] Support quantization configs using module swap
 (#21982)

Signed-off-by: Jerry Zhang
---
 .buildkite/test-pipeline.yaml       |  4 ++++
 tests/quantization/test_torchao.py  | 20 +++++++++++++++++++
 .../layers/quantization/torchao.py  | 16 ++++++++-------
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 51fc9c46e6..35a849d70c 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -507,6 +507,10 @@ steps:
   commands:
   # temporary install here since we need nightly, will move to requirements/test.in
   # after torchao 0.12 release, and pin a working version of torchao nightly here
+
+  # since torchao nightly is currently only compatible with torch nightly
+  # (https://github.com/pytorch/ao/issues/2919), we'll have to skip the new
+  # torchao tests for now; we can only upgrade after this is resolved
  - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index eef3568efe..8e68f6a2e0 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="torchao nightly is currently only compatible with torch nightly "
+    "(https://github.com/pytorch/ao/issues/2919), so we skip torchao tests "
+    "that require newer versions (0.14.0.dev+) for now")
+def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
+                  "-0.14.0.dev")
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 63b2ab6bab..3498d2994c 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    """ 
-    Avoid real weight allocation for faster load, since we will 
+    """
+    Avoid real weight allocation for faster load, since we will
    end up setting it to param.
    """
    with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
+        # the linear can't be the top-level module, since quantize_ is
+        # in-place while some of our configs need to do a module swap, and
+        # only non-top-level modules support module swap
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
 
-        dummy_linear.weight = param
+        dummy_linear[0].weight = param
        quantize_(dummy_linear, torchao_config)
-        return dummy_linear.weight
+        return dummy_linear[0].weight
 
 
 class TorchAOLinearMethod(LinearMethodBase):
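
A minimal standalone sketch (not part of the patch) of the constraint the new
comment in torchao.py describes: quantize_ mutates a module tree in place, and
a config that is implemented via module swap replaces a child module on its
parent, so the module being swapped can't itself be the top-level module
passed to quantize_. Wrapping the linear in nn.Sequential makes it a swappable
child. Int4WeightOnlyConfig, the layer sizes, and the CUDA/bfloat16 setup
below are illustrative assumptions, not taken from the patch.

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    # A bare top-level linear: a module-swap config could not replace this
    # object, since quantize_ mutates in place and the caller's reference
    # would still point at the old module.
    linear = torch.nn.Linear(256, 512, bias=False).to("cuda", torch.bfloat16)

    # Wrapping makes the linear a child module; quantize_ can now either
    # mutate it in place (tensor-subclass configs) or swap it out on the
    # parent (module-swap configs), transparently to the caller.
    wrapper = torch.nn.Sequential(linear)
    quantize_(wrapper, Int4WeightOnlyConfig())

    # Fetch the (possibly swapped) quantized module back out by index.
    print(type(wrapper[0]), wrapper[0].weight.shape)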