diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c35e7e7c..a7f666544 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@
 - [ENH] Added a `clean_names` method for polars - it can be used to clean the column names, or clean column values . Issue #1343 @samukweku
 - [ENH] Improved performance for non-equi joins when using numba - @samukweku PR #1341
 - [ENH] pandas Index,Series, DataFrame now supported in the `complete` method. - PR #1369 @samukweku
+- [ENH] Improved performance for `first/last` in `conditional_join` when the join columns in the right dataframe are sorted. - PR #1382 @samukweku

 ## [v0.27.0] - 2024-03-21

diff --git a/janitor/functions/conditional_join.py b/janitor/functions/conditional_join.py
index f9aaa0891..ebdd36384 100644
--- a/janitor/functions/conditional_join.py
+++ b/janitor/functions/conditional_join.py
@@ -508,7 +508,10 @@ def _conditional_join_compute(
     for condition in conditions:
         left_on, right_on, op = condition
         _conditional_join_type_check(
-            df[left_on], right[right_on], op, use_numba
+            left_column=df[left_on],
+            right_column=right[right_on],
+            op=op,
+            use_numba=use_numba,
         )
         if op == _JoinOperator.STRICTLY_EQUAL.value:
             eq_check = True
@@ -520,19 +523,25 @@
     if (len(conditions) > 1) or eq_check:
         if eq_check:
             result = _multiple_conditional_join_eq(
-                df,
-                right,
-                conditions,
-                keep,
-                use_numba,
-                force,
+                df=df,
+                right=right,
+                conditions=conditions,
+                keep=keep,
+                use_numba=use_numba,
+                force=force,
             )
         elif le_lt_check:
             result = _multiple_conditional_join_le_lt(
-                df, right, conditions, keep, use_numba
+                df=df,
+                right=right,
+                conditions=conditions,
+                keep=keep,
+                use_numba=use_numba,
             )
         else:
-            result = _multiple_conditional_join_ne(df, right, conditions, keep)
+            result = _multiple_conditional_join_ne(
+                df=df, right=right, conditions=conditions, keep=keep
+            )
     else:
         left_on, right_on, op = conditions[0]
         if use_numba:
@@ -558,14 +567,16 @@
     if return_matching_indices:
         return result

+    left_index, right_index = result
     return _create_frame(
-        df,
-        right,
-        *result,
-        how,
-        df_columns,
-        right_columns,
-        indicator,
+        df=df,
+        right=right,
+        left_index=left_index,
+        right_index=right_index,
+        how=how,
+        df_columns=df_columns,
+        right_columns=right_columns,
+        indicator=indicator,
     )


@@ -630,9 +641,9 @@ def _multiple_conditional_join_ne(
     left_on, right_on, op = first

     indices = _generic_func_cond_join(
-        df[left_on],
-        right[right_on],
-        op,
+        left=df[left_on],
+        right=right[right_on],
+        op=op,
         multiple_conditions=False,
         keep="all",
     )
@@ -755,7 +766,9 @@ def _multiple_conditional_join_eq(
         )
         if not right_is_sorted:
             right_df = right_df.sort_values(right_columns)
-        indices = _numba_equi_join(left_df, right_df, eqs, ge_gt, le_lt)
+        indices = _numba_equi_join(
+            df=left_df, right=right_df, eqs=eqs, ge_gt=ge_gt, le_lt=le_lt
+        )
         if not rest or (indices is None):
             return indices

@@ -909,7 +922,25 @@ def _multiple_conditional_join_le_lt(
             if condition not in (ge_gt, le_lt)
         ]

-        indices = _range_indices(df, right, ge_gt, le_lt)
+        if conditions:
+            _keep = None
+        else:
+            first = ge_gt[1]
+            second = le_lt[1]
+            right_is_sorted = (
+                right[first].is_monotonic_increasing
+                & right[second].is_monotonic_increasing
+            )
+            if right_is_sorted:
+                _keep = keep
+            else:
+                _keep = None
+
+        indices = _range_indices(
+            df=df, right=right, first=ge_gt, second=le_lt, keep=_keep
+        )
+        if _keep:
+            return indices

         # no optimised path
         # blow up the rows and prune
@@ -926,9 +957,9 @@
            left_on, right_on, op = ge_gt

        indices = _generic_func_cond_join(
-            df[left_on],
-            right[right_on],
-            op,
+            left=df[left_on],
+            right=right[right_on],
+            op=op,
            multiple_conditions=False,
            keep="all",
        )
@@ -951,6 +982,7 @@ def _range_indices(
     right: pd.DataFrame,
     first: tuple,
     second: tuple,
+    keep: str,
 ) -> Union[tuple[np.ndarray, np.ndarray], None]:
     """
     Retrieve index positions for range/interval joins.
@@ -1019,8 +1051,6 @@
     # this is solved by getting the cumulative max
     # thus ensuring that the first match is obtained
     # via a binary search
-    # this allows us to avoid the less efficient linear search
-    # of using a for loop with a break to get the first match
     outcome = _generic_func_cond_join(
         left=left_c,
         right=right_c.cummax(),
@@ -1055,7 +1085,10 @@
         # this also implies that the intervals
         # do not overlap on the right side
         return left_index, right_index[starts]
-
+    if keep == "first":
+        return left_index, right_index[starts]
+    if keep == "last":
+        return left_index, right_index[ends - 1]
     right_index = [right_index[start:end] for start, end in zip(starts, ends)]
     right_index = np.concatenate(right_index)
     left_index = left_index.repeat(repeater)
diff --git a/janitor/functions/utils.py b/janitor/functions/utils.py
index 89ae1349b..358b47ce4 100644
--- a/janitor/functions/utils.py
+++ b/janitor/functions/utils.py
@@ -813,6 +813,10 @@ def _less_than_indices(

     if multiple_conditions:
         return left_index, right_index, search_indices
+    if right_is_sorted and (keep == "last"):
+        indexer = np.empty_like(search_indices)
+        indexer[:] = len_right - 1
+        return left_index, right_index[indexer]
     if right_is_sorted and (keep == "first"):
         if any_nulls:
             return left_index, right_index[search_indices]
@@ -902,6 +906,9 @@ def _greater_than_indices(

     if multiple_conditions:
         return left_index, right_index, search_indices
+    if right_is_sorted and (keep == "first"):
+        indexer = np.zeros_like(search_indices)
+        return left_index, right_index[indexer]
     if right_is_sorted and (keep == "last"):
         if any_nulls:
             return left_index, right_index[search_indices - 1]
@@ -1043,9 +1050,9 @@ def _keep_output(keep: str, left: np.ndarray, right: np.ndarray):
     grouped = pd.Series(right).groupby(left)
     if keep == "first":
         grouped = grouped.min()
-        return grouped.index, grouped.array
+        return grouped.index, grouped._values
     grouped = grouped.max()
-    return grouped.index, grouped.array
+    return grouped.index, grouped._values


 class col:
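
Note (illustration, not part of the diff): the fast path added above relies on the fact that, when the right join column is already sorted, the first or last match for a single non-equi condition can be read directly off a binary search, so there is no need to materialise every matching pair and prune afterwards. Below is a minimal sketch of the idea, assuming pyjanitor is installed; the frames and the `value`/`start` column names are made up for the example.

    import numpy as np
    import pandas as pd
    import janitor  # noqa: F401  registers the conditional_join accessor

    left = pd.DataFrame({"value": [2, 5, 7]})
    right = pd.DataFrame({"start": [1, 3, 4, 6, 8]})  # sorted -> eligible for the fast path

    # keep="first": for each left row, return only the smallest matching right row
    out = left.conditional_join(right, ("value", "start", "<"), how="inner", keep="first")
    print(out)

    # The same result via the binary-search idea: for a strict "<" join against a
    # sorted right column, the first match is the insertion point returned by
    # searchsorted, and the last match is simply the final row of the right frame.
    search = np.searchsorted(right["start"].to_numpy(), left["value"].to_numpy(), side="right")
    has_match = search < len(right)
    print(left.index[has_match], right.index[search[has_match]])              # keep="first"
    print(left.index[has_match], np.repeat(len(right) - 1, has_match.sum()))  # keep="last" (positions)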