[Core][Model] Terratorch backend integration (#23513)
Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com> Signed-off-by: Christian Pinto <christian.pinto@ibm.com> Co-authored-by: Christian Pinto <christian.pinto@ibm.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@ -45,7 +45,11 @@ datamodule_config = {
|
||||
class PrithviMAE:
|
||||
def __init__(self, model):
    """Construct the wrapper around a vLLM ``LLM`` engine.

    The engine is created with the Terratorch model implementation
    (``model_impl="terratorch"``), tokenizer initialization skipped,
    fp16 weights, and eager (non-compiled) execution.

    Args:
        model: Model name or path forwarded to ``LLM(model=...)``.
    """
    # Collect engine options first, then build the engine in one call.
    engine_options = dict(
        model=model,
        skip_tokenizer_init=True,
        dtype="float16",
        enforce_eager=True,
        model_impl="terratorch",
    )
    self.model = LLM(**engine_options)
|
||||
|
||||
def run(self, input_data, location_coords):
|
||||
|
||||
@ -37,6 +37,7 @@ def main():
|
||||
# The maximum number depends on the available GPU memory
|
||||
max_num_seqs=32,
|
||||
io_processor_plugin="prithvi_to_tiff_india",
|
||||
model_impl="terratorch",
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
||||
|
||||
Reference in New Issue
Block a user