[Misc] Terratorch related fixes (#24337)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams

 def main():
     torch.set_default_dtype(torch.float16)
-    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
+    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

     img_prompt = dict(
         data=image_url,
@@ -36,7 +36,7 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff_india",
+        io_processor_plugin="prithvi_to_tiff",
         model_impl="terratorch",
     )

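For context, the two hunks above amount to the following offline flow once the plugin is renamed. This is a minimal sketch assembled from the visible diff lines, not the full example script: the extra fields of img_prompt and the pooling parameters passed to encode() are omitted, and the model name is taken from the renamed references further down in this commit.

import torch
from vllm import LLM

torch.set_default_dtype(torch.float16)

image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

# The real example passes more fields here (e.g. a data-format hint);
# only the one visible in the hunk is kept.
img_prompt = dict(data=image_url)

llm = LLM(
    model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    skip_tokenizer_init=True,
    enforce_eager=True,
    # Keep this low to avoid the model going OOM; the maximum depends
    # on the available GPU memory.
    max_num_seqs=32,
    io_processor_plugin="prithvi_to_tiff",  # renamed entry point
    model_impl="terratorch",
)

# The example also builds PoolingParams for this call; omitted here.
pooler_output = llm.encode(img_prompt)
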
@@ -18,11 +18,11 @@ import requests
 # --model-impl terratorch
 # --task embed --trust-remote-code
 # --skip-tokenizer-init --enforce-eager
-# --io-processor-plugin prithvi_to_tiff_india
+# --io-processor-plugin prithvi_to_tiff


 def main():
-    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
+    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501
     server_endpoint = "http://localhost:8000/pooling"

     request_payload_url = {

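The online example drives the same plugin through the server's /pooling route. A rough sketch of the request it builds follows; the payload field names are assumptions for illustration only, since the hunk only shows that a request_payload_url dict is constructed around the image URL.

import requests

server_endpoint = "http://localhost:8000/pooling"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

# Field names below are illustrative; see the example script for the
# exact schema expected by the prithvi_to_tiff IO processor.
request_payload_url = {
    "model": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    "data": {"data": image_url},
}

response = requests.post(server_endpoint, json=request_payload_url)
response.raise_for_status()
print(response.json().keys())
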
@@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
 decord==0.6.0
-terratorch==1.1rc3 # required for PrithviMAE test
+terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test

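Switching requirements/test.in from the PyPI pin terratorch==1.1rc3 to a PEP 508 direct reference means the test environment now installs terratorch straight from the tagged Git ref; a standalone equivalent would be roughly pip install "terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3". The compiled requirements/test.txt hunk below pins the same dependency to a specific commit.
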
@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
     # via lightning
 tensorizer==2.10.1
     # via -r requirements/test.in
-terratorch==1.1rc3
+terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn

@@ -11,7 +11,7 @@ import torch

 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
+MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 DTYPE = "float16"

@@ -383,7 +383,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                         trust_remote_code=True),
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
-    "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
+    "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
                                             dtype=torch.float16,
                                             enforce_eager=True,
                                             skip_tokenizer_init=True,
@@ -391,7 +391,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
                                             # going OOM in CI
                                             max_num_seqs=32,
                                             ),
-    "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+    "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
                                   dtype=torch.float16,
                                   enforce_eager=True,
                                   skip_tokenizer_init=True,

@@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads
 @pytest.mark.parametrize(
     "model",
     [
-        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
         "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
     ],
 )

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-def register_prithvi_india():
-    return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia"  # noqa: E501
-
-
-def register_prithvi_valencia():
-    return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia"  # noqa: E501
+def register_prithvi():
+    return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor"  # noqa: E501

@@ -234,6 +234,8 @@ def load_image(

 class PrithviMultimodalDataProcessor(IOProcessor):

+    indices = [0, 1, 2, 3, 4, 5]
+
     def __init__(self, vllm_config: VllmConfig):

         super().__init__(vllm_config)
@@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor):
             format="tiff",
             data=out_data,
             request_id=request_id)
-
-
-class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
-
-    def __init__(self, vllm_config: VllmConfig):
-
-        super().__init__(vllm_config)
-
-        self.indices = [1, 2, 3, 8, 11, 12]
-
-
-class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
-
-    def __init__(self, vllm_config: VllmConfig):
-
-        super().__init__(vllm_config)
-
-        self.indices = [0, 1, 2, 3, 4, 5]

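The removed India/Valencia subclasses differed only in the band indices they selected. Since indices is now a class attribute on the base processor (see the hunk at -234 above), an out-of-tree plugin that still needs the India band layout could restore it with a one-line override; the subclass below is a sketch and is not part of this commit.

from prithvi_io_processor.prithvi_processor import PrithviMultimodalDataProcessor


class PrithviIndiaDataProcessor(PrithviMultimodalDataProcessor):
    # Hypothetical subclass: reinstates the band selection of the removed
    # PrithviMultimodalDataProcessorIndia by overriding the class attribute.
    indices = [1, 2, 3, 8, 11, 12]
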
@@ -9,8 +9,7 @@ setup(
     packages=["prithvi_io_processor"],
     entry_points={
         "vllm.io_processor_plugins": [
-            "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india",  # noqa: E501
-            "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia",  # noqa: E501
+            "prithvi_to_tiff = prithvi_io_processor:register_prithvi",  # noqa: E501
         ]
     },
 )

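With this setup.py change the plugin package exposes a single vllm.io_processor_plugins entry point, prithvi_to_tiff, whose registration function returns the dotted path of the processor class. Whether vLLM resolves it in exactly this way is an assumption, but the mechanics can be illustrated with importlib.metadata:

from importlib.metadata import entry_points

# Look up the renamed entry point (requires the plugin package to be installed).
(plugin,) = [
    ep for ep in entry_points(group="vllm.io_processor_plugins")
    if ep.name == "prithvi_to_tiff"
]
register = plugin.load()   # -> register_prithvi
print(register())          # "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor"
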
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams

-MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
+MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

@@ -35,7 +35,7 @@ def server():
         "--max-num-seqs",
         "32",
         "--io-processor-plugin",
-        "prithvi_to_tiff_valencia",
+        "prithvi_to_tiff",
         "--model-impl",
         "terratorch",
     ]

@@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         # to avoid the model going OOM in CI.
         max_num_seqs=1,
         model_impl="terratorch",
-        io_processor_plugin="prithvi_to_tiff_valencia",
+        io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
         pooler_output = llm_runner.get_llm().encode(
             img_prompt,