[Frontend][4/N] Improve all pooling task | Add plugin pooling task (#26973)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Author: wang.yuqi
Date: 2025-10-23 22:46:18 +08:00, committed by GitHub
parent fe2016de2d
commit 3fa2c12185
16 changed files with 102 additions and 54 deletions

View File

@@ -64,7 +64,7 @@ class PrithviMAE:
         }
         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-        outputs = self.model.encode(prompt, use_tqdm=False)
+        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)
         return outputs[0].outputs.data
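
This hunk adopts the new pooling_task keyword that the commit adds to encode(). A minimal sketch of the resulting call shape, assuming a setup like this example's (the model name is taken from the serving example below; the multimodal payload is a placeholder, not part of this diff):

from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",  # assumed model
    model_impl="terratorch",
    enable_mm_embeds=True,
)
# Placeholder multimodal payload; the real example builds geotiff-derived
# tensors here before wrapping them in the prompt dict.
mm_data = {"pixel_values": ...}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

# pooling_task="plugin" routes pooled outputs through the model's
# IO-processor plugin rather than a built-in task such as "embed".
outputs = llm.encode(prompt, pooling_task="plugin", use_tqdm=False)
print(outputs[0].outputs.data)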

View File

@@ -6,14 +6,14 @@ import os
 import torch
 from vllm import LLM
-from vllm.pooling_params import PoolingParams
 # This example shows how to perform an offline inference that generates
 # multimodal data. In this specific case this example will take a geotiff
 # image as input, process it using the multimodal data processor, and
 # perform inference.
-# Requirement - install plugin at:
-# https://github.com/christian-pinto/prithvi_io_processor_plugin
+# Requirements:
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 def main():
@@ -36,16 +36,12 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff",
+        io_processor_plugin="terratorch_segmentation",
         model_impl="terratorch",
         enable_mm_embeds=True,
     )
-    pooling_params = PoolingParams(task="token_classify", activation=False)
-    pooler_output = llm.encode(
-        img_prompt,
-        pooling_params=pooling_params,
-    )
+    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
     output = pooler_output[0].outputs
     print(output)
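
The hunk above replaces the hand-built PoolingParams with the single pooling_task="plugin" keyword. A before/after sketch under the same assumed setup (img_prompt stands in for the multimodal prompt the example constructs, which is elided here):

from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",  # assumed model
    model_impl="terratorch",
    max_num_seqs=32,
    io_processor_plugin="terratorch_segmentation",
    enable_mm_embeds=True,
)
img_prompt = ...  # built by the example's multimodal pre-processing (elided)

# Before: the caller spelled out pooling behaviour per request.
out_old = llm.encode(
    img_prompt,
    pooling_params=PoolingParams(task="token_classify", activation=False),
)

# After: the "plugin" pooling task carries those choices for plugin models.
out_new = llm.encode(img_prompt, pooling_task="plugin")
print(out_new[0].outputs)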

View File

@@ -11,14 +11,14 @@ import requests
 # image as input, process it using the multimodal data processor, and
 # perform inference.
 # Requirements :
-# - install plugin at:
-#   https://github.com/christian-pinto/prithvi_io_processor_plugin
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 #   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
-#   --io-processor-plugin prithvi_to_tiff
+#   --io-processor-plugin terratorch_segmentation
 #   --enable-mm-embeds
@@ -35,7 +35,6 @@ def main():
         },
         "priority": 0,
         "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
-        "softmax": False,
     }
     ret = requests.post(server_endpoint, json=request_payload_url)
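
With the plugin pooling task selected server-side, the per-request "softmax": False flag is dropped from the payload by this change. A sketch of the trimmed client call (the endpoint path and the contents of "data" are assumptions; the other fields mirror the diff):

import requests

server_endpoint = "http://localhost:8000/pooling"  # assumed pooling endpoint
request_payload_url = {
    # Plugin-specific input; an image URL is an assumed example shape.
    "data": {"type": "url", "image_url": "https://example.com/scene.tif"},
    "priority": 0,
    "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    # "softmax": False is gone: this commit removes the flag from the payload.
}
ret = requests.post(server_endpoint, json=request_payload_url)
print(ret.status_code, ret.text)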