[Misc] Terratorch related fixes (#24337)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Christian Pinto
2025-09-08 15:40:26 +02:00
committed by GitHub
parent e041314184
commit 9cd76b71ab
11 changed files with 18 additions and 37 deletions

View File

@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
def main():
torch.set_default_dtype(torch.float16)
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
img_prompt = dict(
data=image_url,
@ -36,7 +36,7 @@ def main():
# to avoid the model going OOM.
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india",
io_processor_plugin="prithvi_to_tiff",
model_impl="terratorch",
)

View File

@ -18,11 +18,11 @@ import requests
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
# --io-processor-plugin prithvi_to_tiff
def main():
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
server_endpoint = "http://localhost:8000/pooling"
request_payload_url = {

View File

@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test
terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test

View File

@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc3
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn

View File

@ -11,7 +11,7 @@ import torch
from ...utils import RemoteOpenAIServer
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"

View File

@ -383,7 +383,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
@ -391,7 +391,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,

View File

@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads
@pytest.mark.parametrize(
"model",
[
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb"
],
)

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi_india():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501
def register_prithvi_valencia():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501
def register_prithvi():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501

View File

@ -234,6 +234,8 @@ def load_image(
class PrithviMultimodalDataProcessor(IOProcessor):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor):
format="tiff",
data=out_data,
request_id=request_id)
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [1, 2, 3, 8, 11, 12]
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [0, 1, 2, 3, 4, 5]

View File

@ -9,8 +9,7 @@ setup(
packages=["prithvi_io_processor"],
entry_points={
"vllm.io_processor_plugins": [
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501
"prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501
]
},
)

View File

@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
@ -35,7 +35,7 @@ def server():
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff_valencia",
"prithvi_to_tiff",
"--model-impl",
"terratorch",
]
@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
# to avoid the model going OOM in CI.
max_num_seqs=1,
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff_valencia",
io_processor_plugin="prithvi_to_tiff",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,