From 9cd76b71abf15b31878f8d9675546f809a6ba150 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Mon, 8 Sep 2025 15:40:26 +0200 Subject: [PATCH] [Misc] Terratorch related fixes (#24337) Signed-off-by: Christian Pinto Co-authored-by: Cyrus Leung --- .../prithvi_geospatial_mae_io_processor.py | 4 ++-- .../online_serving/prithvi_geospatial_mae.py | 4 ++-- requirements/test.in | 2 +- requirements/test.txt | 2 +- .../entrypoints/openai/test_skip_tokenizer.py | 2 +- tests/models/registry.py | 4 ++-- tests/models/test_terratorch.py | 2 +- .../prithvi_io_processor/__init__.py | 6 ++---- .../prithvi_io_processor/prithvi_processor.py | 20 ++----------------- .../prithvi_io_processor_plugin/setup.py | 3 +-- .../test_io_processor_plugins.py | 6 +++--- 11 files changed, 18 insertions(+), 37 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 5d629fabf0..418c40645f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams def main(): torch.set_default_dtype(torch.float16) - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 img_prompt = dict( data=image_url, @@ -36,7 +36,7 @@ def main(): # to avoid the model going OOM. # The maximum number depends on the available GPU memory max_num_seqs=32, - io_processor_plugin="prithvi_to_tiff_india", + io_processor_plugin="prithvi_to_tiff", model_impl="terratorch", ) diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index c6eed64838..611a7cbc89 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -18,11 +18,11 @@ import requests # --model-impl terratorch # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager -# --io-processor-plugin prithvi_to_tiff_india +# --io-processor-plugin prithvi_to_tiff def main(): - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 server_endpoint = "http://localhost:8000/pooling" request_payload_url = { diff --git a/requirements/test.in b/requirements/test.in index 5db9cd7979..1bbf0074a8 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 decord==0.6.0 -terratorch==1.1rc3 # required for PrithviMAE test +terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test diff --git a/requirements/test.txt b/requirements/test.txt index 332a9b9cfb..65ef7c3c64 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1042,7 +1042,7 @@ tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch==1.1rc3 +terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index af520ac61d..840e0dac81 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -11,7 +11,7 @@ import torch from ...utils import RemoteOpenAIServer -MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" DTYPE = "float16" diff --git a/tests/models/registry.py b/tests/models/registry.py index c6ff50b542..e4c215b108 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,7 +383,7 @@ _EMBEDDING_EXAMPLE_MODELS = { "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 - "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 dtype=torch.float16, enforce_eager=True, skip_tokenizer_init=True, @@ -391,7 +391,7 @@ _EMBEDDING_EXAMPLE_MODELS = { # going OOM in CI max_num_seqs=32, ), - "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 dtype=torch.float16, enforce_eager=True, skip_tokenizer_init=True, diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index bfa54280dc..d6d43ca2f7 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads @pytest.mark.parametrize( "model", [ - "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "mgazz/Prithvi_v2_eo_300_tl_unet_agb" ], ) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py index a750c756c1..4bbb79c98a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -def register_prithvi_india(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501 -def register_prithvi_valencia(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501 +def register_prithvi(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 0ebaafda94..42874f0398 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -234,6 +234,8 @@ def load_image( class PrithviMultimodalDataProcessor(IOProcessor): + indices = [0, 1, 2, 3, 4, 5] + def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) @@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor): format="tiff", data=out_data, request_id=request_id) - - -class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) - - self.indices = [1, 2, 3, 8, 11, 12] - - -class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) - - self.indices = [0, 1, 2, 3, 4, 5] diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py index a03b1fbbd4..3ddda1a47b 100644 --- a/tests/plugins/prithvi_io_processor_plugin/setup.py +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -9,8 +9,7 @@ setup( packages=["prithvi_io_processor"], entry_points={ "vllm.io_processor_plugins": [ - "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501 - "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501 + "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501 ] }, ) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 825165e89b..3567a701a3 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 @@ -35,7 +35,7 @@ def server(): "--max-num-seqs", "32", "--io-processor-plugin", - "prithvi_to_tiff_valencia", + "prithvi_to_tiff", "--model-impl", "terratorch", ] @@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): # to avoid the model going OOM in CI. max_num_seqs=1, model_impl="terratorch", - io_processor_plugin="prithvi_to_tiff_valencia", + io_processor_plugin="prithvi_to_tiff", ) as llm_runner: pooler_output = llm_runner.get_llm().encode( img_prompt,