Skip to content

Commit 2bf8a9a

Browse files
authored
Improve predict batch (#876)
* run per batch, not per image per batch
1 parent ef4435c commit 2bf8a9a

File tree

2 files changed

+43
-99
lines changed

2 files changed

+43
-99
lines changed

src/deepforest/main.py

+4-5
Original file line number | Diff line number | Diff line change
@@ -881,12 +881,11 @@ def predict_batch(self, images, preprocess_fn=None):
881881

882882
# using PyTorch Lightning's predict_step
883883
with torch.no_grad():
884-
predictions = []
885-
for idx, image in enumerate(images):
886-
predictions = self.predict_step(image.unsqueeze(0), idx)
887-
predictions.extend(predictions)
884+
predictions = self.predict_step(images, 0)
885+
888886
#convert predictions to dataframes
889-
results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
887+
results = [utilities.read_file(pred) for pred in predictions if pred is not None]
888+
890889
return results
891890

892891
def configure_optimizers(self):

tests/test_main.py

+39-94
Original file line number | Diff line number | Diff line change
@@ -701,107 +701,52 @@ def test_predict_tile_with_crop_model_empty():
701701
# Assert the result
702702
assert result is None
703703

704-
# @pytest.mark.parametrize("batch_size", [1, 4, 8])
705-
# def test_batch_prediction(m, batch_size, raster_path):
706-
#
707-
# # Prepare input data
708-
# tile = np.array(Image.open(raster_path))
709-
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
710-
# dl = DataLoader(ds, batch_size=batch_size)
711-
712-
# # Perform prediction
713-
# predictions = []
714-
# for batch in dl:
715-
# prediction = m.predict_batch(batch)
716-
# predictions.append(prediction)
717-
718-
# # Check results
719-
# assert len(predictions) == len(dl)
720-
# for batch_pred in predictions:
721-
# assert isinstance(batch_pred, pd.DataFrame)
722-
# assert set(batch_pred.columns) == {
723-
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
724-
# }
725-
726-
# @pytest.mark.parametrize("batch_size", [1, 4])
727-
# def test_batch_training(m, batch_size, tmpdir):
728-
#
729-
# # Generate synthetic training data
730-
# csv_file = get_data("example.csv")
731-
# root_dir = os.path.dirname(csv_file)
732-
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
733-
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
734-
735-
# # Configure the model and trainer
736-
# m.config["batch_size"] = batch_size
737-
# m.create_trainer()
738-
# trainer = m.trainer
739-
740-
# # Train the model
741-
# trainer.fit(m, train_dl)
742-
743-
# # Assertions
744-
# assert trainer.current_epoch == 1
745-
# assert trainer.batch_size == batch_size
746-
747-
# @pytest.mark.parametrize("batch_size", [2, 4])
748-
# def test_batch_data_augmentation(m, batch_size, raster_path):
749-
#
750-
# tile = np.array(Image.open(raster_path))
751-
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
752-
# dl = DataLoader(ds, batch_size=batch_size)
704+
def test_batch_prediction(m, raster_path):
705+
# Prepare input data
706+
tile = np.array(Image.open(raster_path))
707+
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
708+
dl = DataLoader(ds, batch_size=3)
753709

754-
# predictions = []
755-
# for batch in dl:
756-
# prediction = m.predict_batch(batch)
757-
# predictions.append(prediction)
710+
# Perform prediction
711+
predictions = []
712+
for batch in dl:
713+
prediction = m.predict_batch(batch)
714+
predictions.append(prediction)
758715

759-
# assert len(predictions) == len(dl)
760-
# for batch_pred in predictions:
761-
# assert isinstance(batch_pred, pd.DataFrame)
762-
# assert set(batch_pred.columns) == {
763-
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
764-
# }
765-
766-
# def test_batch_inference_consistency(m, raster_path):
767-
#
768-
# tile = np.array(Image.open(raster_path))
769-
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
770-
# dl = DataLoader(ds, batch_size=4)
716+
# Check results
717+
assert len(predictions) == len(dl)
718+
for batch_pred in predictions:
719+
for image_pred in batch_pred:
720+
assert isinstance(image_pred, pd.DataFrame)
721+
assert "label" in image_pred.columns
722+
assert "score" in image_pred.columns
723+
assert "geometry" in image_pred.columns
724+
725+
def test_batch_inference_consistency(m, raster_path):
726+
tile = np.array(Image.open(raster_path))
727+
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
728+
dl = DataLoader(ds, batch_size=4)
771729

772-
# batch_predictions = []
773-
# for batch in dl:
774-
# prediction = m.predict_batch(batch)
775-
# batch_predictions.append(prediction)
730+
batch_predictions = []
731+
for batch in dl:
732+
prediction = m.predict_batch(batch)
733+
batch_predictions.extend(prediction)
776734

777-
# single_predictions = []
778-
# for image in ds:
779-
# prediction = m.predict_image(image=image)
780-
# single_predictions.append(prediction)
735+
single_predictions = []
736+
for image in ds:
737+
image = image.permute(1,2,0).numpy() * 255
738+
prediction = m.predict_image(image=image)
739+
single_predictions.append(prediction)
781740

782-
# batch_df = pd.concat(batch_predictions, ignore_index=True)
783-
# single_df = pd.concat(single_predictions, ignore_index=True)
741+
batch_df = pd.concat(batch_predictions, ignore_index=True)
742+
single_df = pd.concat(single_predictions, ignore_index=True)
784743

785-
# pd.testing.assert_frame_equal(batch_df, single_df)
744+
# Make all xmin, ymin, xmax, ymax integers
745+
for col in ["xmin", "ymin", "xmax", "ymax"]:
746+
batch_df[col] = batch_df[col].astype(int)
747+
single_df[col] = single_df[col].astype(int)
748+
pd.testing.assert_frame_equal(batch_df[["xmin", "ymin", "xmax", "ymax"]], single_df[["xmin", "ymin", "xmax", "ymax"]], check_dtype=False)
786749

787-
# def test_large_batch_handling(m, raster_path):
788-
#
789-
# tile = np.array(Image.open(raster_path))
790-
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
791-
# dl = DataLoader(ds, batch_size=16)
792-
793-
# predictions = []
794-
# for batch in dl:
795-
# prediction = m.predict_batch(batch)
796-
# predictions.append(prediction)
797-
798-
# assert len(predictions) > 0
799-
# for batch_pred in predictions:
800-
# assert isinstance(batch_pred, pd.DataFrame)
801-
# assert set(batch_pred.columns) == {
802-
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
803-
# }
804-
# assert not batch_pred.empty
805750

806751
def test_epoch_evaluation_end(m):
807752
preds = [{

0 commit comments

Comments
 (0)