@@ -701,107 +701,52 @@ def test_predict_tile_with_crop_model_empty():
701
701
# Assert the result
702
702
assert result is None
703
703
704
- # @pytest.mark.parametrize("batch_size", [1, 4, 8])
705
- # def test_batch_prediction(m, batch_size, raster_path):
706
- #
707
- # # Prepare input data
708
- # tile = np.array(Image.open(raster_path))
709
- # ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
710
- # dl = DataLoader(ds, batch_size=batch_size)
711
-
712
- # # Perform prediction
713
- # predictions = []
714
- # for batch in dl:
715
- # prediction = m.predict_batch(batch)
716
- # predictions.append(prediction)
717
-
718
- # # Check results
719
- # assert len(predictions) == len(dl)
720
- # for batch_pred in predictions:
721
- # assert isinstance(batch_pred, pd.DataFrame)
722
- # assert set(batch_pred.columns) == {
723
- # "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
724
- # }
725
-
726
- # @pytest.mark.parametrize("batch_size", [1, 4])
727
- # def test_batch_training(m, batch_size, tmpdir):
728
- #
729
- # # Generate synthetic training data
730
- # csv_file = get_data("example.csv")
731
- # root_dir = os.path.dirname(csv_file)
732
- # train_ds = m.load_dataset(csv_file, root_dir=root_dir)
733
- # train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
734
-
735
- # # Configure the model and trainer
736
- # m.config["batch_size"] = batch_size
737
- # m.create_trainer()
738
- # trainer = m.trainer
739
-
740
- # # Train the model
741
- # trainer.fit(m, train_dl)
742
-
743
- # # Assertions
744
- # assert trainer.current_epoch == 1
745
- # assert trainer.batch_size == batch_size
746
-
747
- # @pytest.mark.parametrize("batch_size", [2, 4])
748
- # def test_batch_data_augmentation(m, batch_size, raster_path):
749
- #
750
- # tile = np.array(Image.open(raster_path))
751
- # ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
752
- # dl = DataLoader(ds, batch_size=batch_size)
704
+ def test_batch_prediction (m , raster_path ):
705
+ # Prepare input data
706
+ tile = np .array (Image .open (raster_path ))
707
+ ds = dataset .TileDataset (tile = tile , patch_overlap = 0.1 , patch_size = 300 )
708
+ dl = DataLoader (ds , batch_size = 3 )
753
709
754
- # predictions = []
755
- # for batch in dl:
756
- # prediction = m.predict_batch(batch)
757
- # predictions.append(prediction)
710
+ # Perform prediction
711
+ predictions = []
712
+ for batch in dl :
713
+ prediction = m .predict_batch (batch )
714
+ predictions .append (prediction )
758
715
759
- # assert len(predictions) == len(dl)
760
- # for batch_pred in predictions:
761
- # assert isinstance(batch_pred, pd.DataFrame)
762
- # assert set(batch_pred.columns) == {
763
- # "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
764
- # }
765
-
766
- # def test_batch_inference_consistency(m, raster_path):
767
- #
768
- # tile = np.array(Image.open(raster_path))
769
- # ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
770
- # dl = DataLoader(ds, batch_size=4)
716
+ # Check results
717
+ assert len (predictions ) == len (dl )
718
+ for batch_pred in predictions :
719
+ for image_pred in batch_pred :
720
+ assert isinstance (image_pred , pd .DataFrame )
721
+ assert "label" in image_pred .columns
722
+ assert "score" in image_pred .columns
723
+ assert "geometry" in image_pred .columns
724
+
725
+ def test_batch_inference_consistency (m , raster_path ):
726
+ tile = np .array (Image .open (raster_path ))
727
+ ds = dataset .TileDataset (tile = tile , patch_overlap = 0.1 , patch_size = 300 )
728
+ dl = DataLoader (ds , batch_size = 4 )
771
729
772
- # batch_predictions = []
773
- # for batch in dl:
774
- # prediction = m.predict_batch(batch)
775
- # batch_predictions.append (prediction)
730
+ batch_predictions = []
731
+ for batch in dl :
732
+ prediction = m .predict_batch (batch )
733
+ batch_predictions .extend (prediction )
776
734
777
- # single_predictions = []
778
- # for image in ds:
779
- # prediction = m.predict_image(image=image)
780
- # single_predictions.append(prediction)
735
+ single_predictions = []
736
+ for image in ds :
737
+ image = image .permute (1 ,2 ,0 ).numpy () * 255
738
+ prediction = m .predict_image (image = image )
739
+ single_predictions .append (prediction )
781
740
782
- # batch_df = pd.concat(batch_predictions, ignore_index=True)
783
- # single_df = pd.concat(single_predictions, ignore_index=True)
741
+ batch_df = pd .concat (batch_predictions , ignore_index = True )
742
+ single_df = pd .concat (single_predictions , ignore_index = True )
784
743
785
- # pd.testing.assert_frame_equal(batch_df, single_df)
744
+ # Make all xmin, ymin, xmax, ymax integers
745
+ for col in ["xmin" , "ymin" , "xmax" , "ymax" ]:
746
+ batch_df [col ] = batch_df [col ].astype (int )
747
+ single_df [col ] = single_df [col ].astype (int )
748
+ pd .testing .assert_frame_equal (batch_df [["xmin" , "ymin" , "xmax" , "ymax" ]], single_df [["xmin" , "ymin" , "xmax" , "ymax" ]], check_dtype = False )
786
749
787
- # def test_large_batch_handling(m, raster_path):
788
- #
789
- # tile = np.array(Image.open(raster_path))
790
- # ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
791
- # dl = DataLoader(ds, batch_size=16)
792
-
793
- # predictions = []
794
- # for batch in dl:
795
- # prediction = m.predict_batch(batch)
796
- # predictions.append(prediction)
797
-
798
- # assert len(predictions) > 0
799
- # for batch_pred in predictions:
800
- # assert isinstance(batch_pred, pd.DataFrame)
801
- # assert set(batch_pred.columns) == {
802
- # "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
803
- # }
804
- # assert not batch_pred.empty
805
750
806
751
def test_epoch_evaluation_end (m ):
807
752
preds = [{
0 commit comments