Skip to content

Commit

Permalink
Parquet/Arrow: implement faster spatial filtering with ArrowArray int…
Browse files Browse the repository at this point in the history
…erface

We no longer fall back to the slow, generic implementation that goes
through GetNextFeature(); instead, we directly post-filter the ArrowArray
to remove features that do not intersect the spatial filter.

The performance gain is mostly seen when a large number of features is
selected. The fallback GetNextFeature() already has efficient spatial
filtering, so when only a small number of features is selected, this
optimization does not bring anything.
It can be up to 10x faster when a large number of features is selected.

Now:
```

$ time bench_ogr_batch nz-building-outlines.parquet -spat 1167513 4794680 2089113 6190596

real	0m1,275s
user	0m1,565s
sys	0m0,322s
```

Before:
```
$ time bench_ogr_batch nz-building-outlines.parquet -spat 1167513 4794680 2089113 6190596

real	0m13,507s
user	0m13,728s
sys	0m0,712s

```
  • Loading branch information
rouault committed May 28, 2023
1 parent f5997fe commit ebc2f9e
Show file tree
Hide file tree
Showing 6 changed files with 778 additions and 53 deletions.
83 changes: 83 additions & 0 deletions autotest/ogr/ogr_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,89 @@ def test_ogr_parquet_arrow_stream_numpy():
)


###############################################################################


def test_ogr_parquet_arrow_stream_numpy_fast_spatial_filter():
    """Read data/parquet/test.parquet as NumPy Arrow batches with a spatial filter.

    Fields whose name or OGR type is not in the accepted set below are
    ignored. For three successive rectangular spatial filters, the test
    checks that the layer still advertises OLCFastGetArrowStream and that
    the returned batches contain the expected rows:

    * (-10, -10, 10, 10): 4 features in total across all batches;
    * (3, 2, 3, 2): exactly one batch holding one feature, whose
      attribute values and geometry are checked;
    * (1, 1, 1, 1): no batch at all.
    """
    pytest.importorskip("osgeo.gdal_array")
    numpy = pytest.importorskip("numpy")
    import datetime

    ds = ogr.Open("data/parquet/test.parquet")
    lyr = ds.GetLayer(0)

    # Field name prefixes and OGR field types used to decide which fields
    # are kept; everything else goes into the ignored list.
    rejected_prefixes = ("map_", "struct_", "fixed_size_")
    accepted_types = (
        ogr.OFTInteger,
        ogr.OFTInteger64,
        ogr.OFTReal,
        ogr.OFTString,
        ogr.OFTBinary,
        ogr.OFTTime,
        ogr.OFTDate,
        ogr.OFTDateTime,
    )

    ignored_fields = ["decimal128", "decimal256", "time64_ns"]
    lyr_defn = lyr.GetLayerDefn()
    for i in range(lyr_defn.GetFieldCount()):
        fld_defn = lyr_defn.GetFieldDefn(i)
        # Fetch the name once instead of once per prefix test.
        fld_name = fld_defn.GetName()
        if (
            fld_name.startswith(rejected_prefixes)
            or fld_defn.GetType() not in accepted_types
        ):
            ignored_fields.append(fld_name)
    lyr.SetIgnoredFields(ignored_fields)

    # Filter rectangle (-10, -10, 10, 10): 4 features expected in total.
    lyr.SetSpatialFilterRect(-10, -10, 10, 10)
    assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1

    stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
    fc = 0
    for batch in stream:
        fc += len(batch["uint8"])
    assert fc == 4

    # Degenerate rectangle at (3, 2): one batch with a single feature.
    lyr.SetSpatialFilterRect(3, 2, 3, 2)
    assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1

    stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
    batches = list(stream)
    assert len(batches) == 1
    batch = batches[0]
    assert len(batch["geometry"]) == 1
    assert batch["boolean"][0] == False
    assert batch["uint8"][0] == 4
    assert batch["int8"][0] == 1
    assert batch["uint16"][0] == 30001
    assert batch["int16"][0] == 10000
    assert batch["uint32"][0] == 3000000001
    assert batch["int32"][0] == 1000000000
    assert batch["uint64"][0] == 300000000001
    assert batch["int64"][0] == 100000000000
    assert batch["float32"][0] == 4.5
    assert batch["float64"][0] == 4.5
    assert batch["string"][0] == b"c"
    assert batch["large_string"][0] == b"c"
    assert batch["timestamp_ms_gmt"][0] == numpy.datetime64("2019-01-01T14:00:00.000")
    assert batch["time32_s"][0] == datetime.time(0, 0, 4)
    assert batch["time32_ms"][0] == datetime.time(0, 0, 0, 4000)
    assert batch["time64_us"][0] == datetime.time(0, 0, 0, 4)
    assert batch["date32"][0] == numpy.datetime64("1970-01-05")
    assert batch["date64"][0] == numpy.datetime64("1970-01-01")
    assert bytes(batch["binary"][0]) == b"\00\01"
    assert bytes(batch["large_binary"][0]) == b"\00\01"
    assert (
        ogr.CreateGeometryFromWkb(batch["geometry"][0]).ExportToWkt() == "POINT (3 2)"
    )

    # Degenerate rectangle at (1, 1): the stream must yield no batch.
    lyr.SetSpatialFilterRect(1, 1, 1, 1)
    assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1

    stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
    batches = list(stream)
    assert len(batches) == 0


###############################################################################
# Test bbox

Expand Down
9 changes: 9 additions & 0 deletions ogr/ogrsf_frmts/arrow_common/ogr_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class OGRArrowLayer CPL_NON_FINAL
std::vector<Constraint> m_asAttributeFilterConstraints{};
int m_nUseOptimizedAttributeFilter = -1;
bool m_bSpatialFilterIntersectsLayerExtent = true;
bool m_bUseRecordBatchBaseImplementation = false;

// Modified by UseRecordBatchBaseImplementation()
mutable struct ArrowSchema m_sCachedSchema = {};

bool SkipToNextFeatureDueToAttributeFilter() const;
void ExploreExprNode(const swq_expr_node *poNode);
Expand All @@ -93,6 +97,8 @@ class OGRArrowLayer CPL_NON_FINAL
static struct ArrowArray *
CreateWKTArrayFromWKBArray(const struct ArrowArray *sourceArray);

int GetArrowSchemaInternal(struct ArrowSchema *out) const;

protected:
OGRArrowDataset *m_poArrowDS = nullptr;
arrow::MemoryPool *m_poMemoryPool = nullptr;
Expand Down Expand Up @@ -214,6 +220,9 @@ class OGRArrowLayer CPL_NON_FINAL

int TestCapability(const char *pszCap) override;

bool GetArrowStream(struct ArrowArrayStream *out_stream,
CSLConstList papszOptions = nullptr) override;

virtual std::unique_ptr<OGRFieldDomain>
BuildDomain(const std::string &osDomainName, int iFieldIndex) const = 0;

Expand Down
163 changes: 112 additions & 51 deletions ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ inline OGRArrowLayer::OGRArrowLayer(OGRArrowDataset *poDS,

inline OGRArrowLayer::~OGRArrowLayer()
{
if (m_sCachedSchema.release)
m_sCachedSchema.release(&m_sCachedSchema);

CPLDebug("ARROW", "Memory pool: bytes_allocated = %" PRId64,
m_poMemoryPool->bytes_allocated());
CPLDebug("ARROW", "Memory pool: max_memory = %" PRId64,
Expand Down Expand Up @@ -3237,7 +3240,7 @@ static void OverrideArrowRelease(OGRArrowDataset *poDS, T *obj)

inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const
{
if (m_poAttrQuery != nullptr || m_poFilterGeom != nullptr ||
if (m_poAttrQuery != nullptr ||
CPLTestBool(CPLGetConfigOption("OGR_ARROW_STREAM_BASE_IMPL", "NO")))
{
return true;
Expand All @@ -3254,6 +3257,8 @@ inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const
m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB &&
m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKT)
{
CPLDebug("ARROW", "Geometry encoding not compatible of fast "
"Arrow implementation");
return true;
}
}
Expand All @@ -3279,25 +3284,64 @@ inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const
// struct fields will point to the same arrow column
if (ignoredState[nArrowCol] != static_cast<int>(bIsIgnored))
{
CPLDebug("ARROW",
"Inconsistent ignore state for Arrow Columns");
return true;
}
}
}
}

if (m_poFilterGeom)
{
struct ArrowSchema *psSchema = &m_sCachedSchema;
if (psSchema->release)
psSchema->release(psSchema);
memset(psSchema, 0, sizeof(*psSchema));

const bool bCanPostFilter = GetArrowSchemaInternal(psSchema) == 0 &&
CanPostFilterArrowArray(psSchema);
if (!bCanPostFilter)
return true;
}

return false;
}

/************************************************************************/
/* GetArrowStream() */
/************************************************************************/

inline bool OGRArrowLayer::GetArrowStream(struct ArrowArrayStream *out_stream,
                                          CSLConstList papszOptions)
{
    // Let the base class set up the stream first; on failure the cached
    // implementation choice is left untouched.
    const bool bStreamOK = OGRLayer::GetArrowStream(out_stream, papszOptions);
    if (bStreamOK)
    {
        // Record the result of UseRecordBatchBaseImplementation() in the
        // member flag at stream creation time.
        m_bUseRecordBatchBaseImplementation =
            UseRecordBatchBaseImplementation();
    }
    return bStreamOK;
}

/************************************************************************/
/* GetArrowSchema() */
/************************************************************************/

inline int OGRArrowLayer::GetArrowSchema(struct ArrowArrayStream *stream,
struct ArrowSchema *out_schema)
{
if (UseRecordBatchBaseImplementation())
if (m_bUseRecordBatchBaseImplementation)
return OGRLayer::GetArrowSchema(stream, out_schema);

return GetArrowSchemaInternal(out_schema);
}

/************************************************************************/
/* GetArrowSchemaInternal() */
/************************************************************************/

inline int
OGRArrowLayer::GetArrowSchemaInternal(struct ArrowSchema *out_schema) const
{
auto status = arrow::ExportSchema(*m_poSchema, out_schema);
if (!status.ok())
{
Expand Down Expand Up @@ -3420,76 +3464,93 @@ inline int OGRArrowLayer::GetArrowSchema(struct ArrowArrayStream *stream,
inline int OGRArrowLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
struct ArrowArray *out_array)
{
if (UseRecordBatchBaseImplementation())
if (m_bUseRecordBatchBaseImplementation)
return OGRLayer::GetNextArrowArray(stream, out_array);

if (m_bEOF)
{
memset(out_array, 0, sizeof(*out_array));
return 0;
}

if (m_poBatch == nullptr || m_nIdxInBatch == m_poBatch->num_rows())
while (true)
{
m_bEOF = !ReadNextBatch();
if (m_bEOF)
{
memset(out_array, 0, sizeof(*out_array));
return 0;
}
}

auto status = arrow::ExportRecordBatch(*m_poBatch, out_array, nullptr);
m_nIdxInBatch = m_poBatch->num_rows();
if (!status.ok())
{
CPLError(CE_Failure, CPLE_AppDefined,
"ExportRecordBatch() failed with %s",
status.message().c_str());
return EIO;
}
if (m_poBatch == nullptr || m_nIdxInBatch == m_poBatch->num_rows())
{
m_bEOF = !ReadNextBatch();
if (m_bEOF)
{
memset(out_array, 0, sizeof(*out_array));
return 0;
}
}

if (EQUAL(m_aosArrowArrayStreamOptions.FetchNameValueDef(
"GEOMETRY_ENCODING", ""),
"WKB"))
{
const int nGeomFieldCount = m_poFeatureDefn->GetGeomFieldCount();
for (int i = 0; i < nGeomFieldCount; i++)
auto status = arrow::ExportRecordBatch(*m_poBatch, out_array, nullptr);
m_nIdxInBatch = m_poBatch->num_rows();
if (!status.ok())
{
const auto poGeomFieldDefn = m_poFeatureDefn->GetGeomFieldDefn(i);
if (!poGeomFieldDefn->IsIgnored())
CPLError(CE_Failure, CPLE_AppDefined,
"ExportRecordBatch() failed with %s",
status.message().c_str());
return EIO;
}

if (EQUAL(m_aosArrowArrayStreamOptions.FetchNameValueDef(
"GEOMETRY_ENCODING", ""),
"WKB"))
{
const int nGeomFieldCount = m_poFeatureDefn->GetGeomFieldCount();
for (int i = 0; i < nGeomFieldCount; i++)
{
if (m_aeGeomEncoding[i] == OGRArrowGeomEncoding::WKT)
const auto poGeomFieldDefn =
m_poFeatureDefn->GetGeomFieldDefn(i);
if (!poGeomFieldDefn->IsIgnored())
{
const int nArrayIdx =
m_bIgnoredFields
? m_anMapGeomFieldIndexToArrayIndex[i]
: m_anMapGeomFieldIndexToArrowColumn[i];
auto sourceArray = out_array->children[nArrayIdx];
auto targetArray = CreateWKTArrayFromWKBArray(sourceArray);
if (targetArray)
if (m_aeGeomEncoding[i] == OGRArrowGeomEncoding::WKT)
{
sourceArray->release(sourceArray);
out_array->children[nArrayIdx] = targetArray;
const int nArrayIdx =
m_bIgnoredFields
? m_anMapGeomFieldIndexToArrayIndex[i]
: m_anMapGeomFieldIndexToArrowColumn[i];
auto sourceArray = out_array->children[nArrayIdx];
auto targetArray =
CreateWKTArrayFromWKBArray(sourceArray);
if (targetArray)
{
sourceArray->release(sourceArray);
out_array->children[nArrayIdx] = targetArray;
}
else
{
out_array->release(out_array);
memset(out_array, 0, sizeof(*out_array));
return ENOMEM;
}
}
else
else if (m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB)
{
out_array->release(out_array);
memset(out_array, 0, sizeof(*out_array));
return ENOMEM;
// Shouldn't happen if UseRecordBatchBaseImplementation()
// is up to date
CPLAssert(false);
}
}
else if (m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB)
{
// Shouldn't happen if UseRecordBatchBaseImplementation()
// is up to date
CPLAssert(false);
}
}
}
}

OverrideArrowRelease(m_poArrowDS, out_array);
OverrideArrowRelease(m_poArrowDS, out_array);

if (m_poFilterGeom)
{
PostFilterArrowArray(&m_sCachedSchema, out_array);
if (out_array->length == 0)
{
// If there are no records after filtering, start again
// with a new batch
continue;
}
}
break;
}

return 0;
}
Expand Down
Loading

0 comments on commit ebc2f9e

Please sign in to comment.