[New Model]: Support GteNewModelForSequenceClassification (#23524)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@ -456,11 +456,10 @@ class HfRunner:
|
||||
# output is final logits
|
||||
all_inputs = self.get_inputs(prompts)
|
||||
outputs = []
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
for inputs in all_inputs:
|
||||
output = self.model(**self.wrap_device(inputs))
|
||||
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
if problem_type == "regression":
|
||||
logits = output.logits[0].tolist()
|
||||
elif problem_type == "multi_label_classification":
|
||||
|
||||
Reference in New Issue
Block a user