diff --git a/senta/train.py b/senta/train.py index f0c4056..714241e 100644 --- a/senta/train.py +++ b/senta/train.py @@ -278,7 +278,7 @@ def predict(self, texts_, aspects=None): else: for text, probs in zip(texts_, batch_result): label = self.label_map[np.argmax(probs)] - results.append((text, label)) + results.append((text, label,probs)) return results def train(self, json_path):