diff --git a/janitor/polars/complete.py b/janitor/polars/complete.py index ef098ede7..546f903bc 100644 --- a/janitor/polars/complete.py +++ b/janitor/polars/complete.py @@ -385,14 +385,14 @@ def _complete( no_columns_to_fill = set(df.columns) == set(uniques.columns) if fill_value is None or no_columns_to_fill: - return uniques.join(df, on=uniques.columns, how="full", coalesce=True) + return uniques.join(df, on=uniques.columns, how="left", coalesce=True) idx = None columns_to_select = df.columns if not explicit: idx = "".join(df.columns) idx = f"{idx}_" df = df.with_row_index(name=idx) - df = uniques.join(df, on=uniques.columns, how="full", coalesce=True) + df = uniques.join(df, on=uniques.columns, how="left", coalesce=True) # exclude columns that were not used # to generate the combinations exclude_columns = uniques.columns diff --git a/janitor/polars/pivot_longer.py b/janitor/polars/pivot_longer.py index ff11fbc44..9dea2581f 100644 --- a/janitor/polars/pivot_longer.py +++ b/janitor/polars/pivot_longer.py @@ -2,16 +2,12 @@ from __future__ import annotations -from collections import defaultdict -from typing import Any, Iterable - from janitor.utils import check, import_message from .polars_flavor import register_dataframe_method, register_lazyframe_method try: import polars as pl - import polars.selectors as cs from polars.type_aliases import ColumnNameOrSelector except ImportError: import_message( @@ -37,14 +33,14 @@ def pivot_longer_spec( becomes variables. It can come in handy for situations where - `janitor.polars.pivot_longer` + [`pivot_longer`][janitor.polars.pivot_longer.pivot_longer] seems inadequate for the transformation. !!! info "New in version 0.28.0" Examples: >>> import pandas as pd - >>> import janitor.polars + >>> from janitor.polars import pivot_longer_spec >>> df = pl.DataFrame( ... { ... "Sepal.Length": [5.1, 5.9], @@ -81,18 +77,18 @@ def pivot_longer_spec( │ Sepal.Width ┆ Width ┆ Sepal │ │ Petal.Width ┆ Width ┆ Petal │ └──────────────┴────────┴───────┘ - >>> df.pipe(pivot_longer_spec,spec=spec) + >>> df.pipe(pivot_longer_spec,spec=spec).sort(by=pl.all()) shape: (4, 4) - ┌───────────┬────────┬───────┬───────┐ - │ Species ┆ Length ┆ Width ┆ part │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ f64 ┆ f64 ┆ str │ - ╞═══════════╪════════╪═══════╪═══════╡ - │ setosa ┆ 5.1 ┆ 3.5 ┆ Sepal │ - │ virginica ┆ 5.9 ┆ 3.0 ┆ Sepal │ - │ setosa ┆ 1.4 ┆ 0.2 ┆ Petal │ - │ virginica ┆ 5.1 ┆ 1.8 ┆ Petal │ - └───────────┴────────┴───────┴───────┘ + ┌───────────┬───────┬────────┬───────┐ + │ Species ┆ part ┆ Length ┆ Width │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ f64 ┆ f64 │ + ╞═══════════╪═══════╪════════╪═══════╡ + │ setosa ┆ Petal ┆ 1.4 ┆ 0.2 │ + │ setosa ┆ Sepal ┆ 5.1 ┆ 3.5 │ + │ virginica ┆ Petal ┆ 5.1 ┆ 1.8 │ + │ virginica ┆ Sepal ┆ 5.9 ┆ 3.0 │ + └───────────┴───────┴────────┴───────┘ Args: df: The source DataFrame to unpivot. @@ -140,17 +136,29 @@ def pivot_longer_spec( "Kindly ensure the spec DataFrame's columns " "are not present in the source DataFrame." ) - - if spec.columns[:2] != [".name", ".value"]: - raise ValueError( - "The first two columns of the spec DataFrame " - "should be '.name' and '.value', " - "with '.name' coming before '.value'." - ) - + index = [ + label for label in df.columns if label not in spec.get_column(".name") + ] + others = [ + label for label in spec.columns if label not in {".name", ".value"} + ] + variable_name = "".join(df.columns + spec.columns) + variable_name = f"{variable_name}_" + if others: + dot_value_only = False + expression = pl.struct(others).alias(variable_name) + spec = spec.select(".name", ".value", expression) + else: + dot_value_only = True + expression = pl.cum_count(".value").over(".value").alias(variable_name) + spec = spec.with_columns(expression) return _pivot_longer_dot_value( df=df, + index=index, spec=spec, + variable_name=variable_name, + dot_value_only=dot_value_only, + names_transform=None, ) @@ -179,8 +187,11 @@ def pivot_longer( All measured variables are *unpivoted* (and typically duplicated) along the row axis. + If `names_pattern`, use a valid regular expression pattern containing at least + one capture group, compatible with the [regex crate](https://docs.rs/regex/latest/regex/). + For more granular control on the unpivoting, have a look at - `pivot_longer_spec`. + [`pivot_longer_spec`][janitor.polars.pivot_longer.pivot_longer_spec]. `pivot_longer` can also be applied to a LazyFrame. @@ -209,21 +220,21 @@ def pivot_longer( └──────────────┴─────────────┴──────────────┴─────────────┴───────────┘ Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html#polars-dataframe-melt): - >>> df.pivot_longer(index = 'Species') + >>> df.pivot_longer(index = 'Species').sort(by=pl.all()) shape: (8, 3) ┌───────────┬──────────────┬───────┐ │ Species ┆ variable ┆ value │ │ --- ┆ --- ┆ --- │ │ str ┆ str ┆ f64 │ ╞═══════════╪══════════════╪═══════╡ + │ setosa ┆ Petal.Length ┆ 1.4 │ + │ setosa ┆ Petal.Width ┆ 0.2 │ │ setosa ┆ Sepal.Length ┆ 5.1 │ - │ virginica ┆ Sepal.Length ┆ 5.9 │ │ setosa ┆ Sepal.Width ┆ 3.5 │ - │ virginica ┆ Sepal.Width ┆ 3.0 │ - │ setosa ┆ Petal.Length ┆ 1.4 │ │ virginica ┆ Petal.Length ┆ 5.1 │ - │ setosa ┆ Petal.Width ┆ 0.2 │ │ virginica ┆ Petal.Width ┆ 1.8 │ + │ virginica ┆ Sepal.Length ┆ 5.9 │ + │ virginica ┆ Sepal.Width ┆ 3.0 │ └───────────┴──────────────┴───────┘ Split the column labels into individual columns: @@ -231,21 +242,21 @@ def pivot_longer( ... index = 'Species', ... names_to = ('part', 'dimension'), ... names_sep = '.', - ... ).select('Species','part','dimension','value') + ... ).select('Species','part','dimension','value').sort(by=pl.all()) shape: (8, 4) ┌───────────┬───────┬───────────┬───────┐ │ Species ┆ part ┆ dimension ┆ value │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ f64 │ ╞═══════════╪═══════╪═══════════╪═══════╡ + │ setosa ┆ Petal ┆ Length ┆ 1.4 │ + │ setosa ┆ Petal ┆ Width ┆ 0.2 │ │ setosa ┆ Sepal ┆ Length ┆ 5.1 │ - │ virginica ┆ Sepal ┆ Length ┆ 5.9 │ │ setosa ┆ Sepal ┆ Width ┆ 3.5 │ - │ virginica ┆ Sepal ┆ Width ┆ 3.0 │ - │ setosa ┆ Petal ┆ Length ┆ 1.4 │ │ virginica ┆ Petal ┆ Length ┆ 5.1 │ - │ setosa ┆ Petal ┆ Width ┆ 0.2 │ │ virginica ┆ Petal ┆ Width ┆ 1.8 │ + │ virginica ┆ Sepal ┆ Length ┆ 5.9 │ + │ virginica ┆ Sepal ┆ Width ┆ 3.0 │ └───────────┴───────┴───────────┴───────┘ Retain parts of the column names as headers: @@ -253,17 +264,17 @@ def pivot_longer( ... index = 'Species', ... names_to = ('part', '.value'), ... names_sep = '.', - ... ).select('Species','part','Length','Width') + ... ).select('Species','part','Length','Width').sort(by=pl.all()) shape: (4, 4) ┌───────────┬───────┬────────┬───────┐ │ Species ┆ part ┆ Length ┆ Width │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ f64 ┆ f64 │ ╞═══════════╪═══════╪════════╪═══════╡ - │ setosa ┆ Sepal ┆ 5.1 ┆ 3.5 │ - │ virginica ┆ Sepal ┆ 5.9 ┆ 3.0 │ │ setosa ┆ Petal ┆ 1.4 ┆ 0.2 │ + │ setosa ┆ Sepal ┆ 5.1 ┆ 3.5 │ │ virginica ┆ Petal ┆ 5.1 ┆ 1.8 │ + │ virginica ┆ Sepal ┆ 5.9 ┆ 3.0 │ └───────────┴───────┴────────┴───────┘ Split the column labels based on regex: @@ -393,7 +404,7 @@ def _pivot_longer( df: pl.DataFrame | pl.LazyFrame, index: ColumnNameOrSelector, column_names: ColumnNameOrSelector, - names_to: list | tuple | str, + names_to: list | tuple | str | None, values_to: str, names_sep: str, names_pattern: str, @@ -403,6 +414,14 @@ def _pivot_longer( Unpivots a DataFrame/LazyFrame from wide to long form. """ + if all((names_pattern is None, names_sep is None)): + return df.melt( + id_vars=index, + value_vars=column_names, + variable_name=names_to, + value_name=values_to, + ) + ( df, index, @@ -411,7 +430,6 @@ def _pivot_longer( values_to, names_sep, names_pattern, - names_transform, ) = _data_checks_pivot_longer( df=df, index=index, @@ -420,43 +438,53 @@ def _pivot_longer( values_to=values_to, names_sep=names_sep, names_pattern=names_pattern, - names_transform=names_transform, ) - if not column_names: - return df - - if all((names_pattern is None, names_sep is None)): - return df.melt( - id_vars=index, - value_vars=column_names, - variable_name=names_to, - value_name=values_to, - ) - - df = df.select(pl.col(index), pl.col(column_names)) - if isinstance(names_to, str): - names_to = [names_to] - + variable_name = "".join(df.columns) + variable_name = f"{variable_name}_" spec = _pivot_longer_create_spec( column_names=column_names, names_to=names_to, names_sep=names_sep, names_pattern=names_pattern, - values_to=values_to, - names_transform=names_transform, + variable_name=variable_name, ) - return _pivot_longer_dot_value(df=df, spec=spec) + if ".value" not in names_to: + return _pivot_longer_no_dot_value( + df=df, + index=index, + spec=spec, + column_names=column_names, + names_to=names_to, + values_to=values_to, + variable_name=variable_name, + names_transform=names_transform, + ) + + if {".name", ".value"}.symmetric_difference(spec.columns): + dot_value_only = False + else: + dot_value_only = True + expression = pl.cum_count(".value").over(".value").alias(variable_name) + spec = spec.with_columns(expression) + + return _pivot_longer_dot_value( + df=df, + index=index, + spec=spec, + variable_name=variable_name, + dot_value_only=dot_value_only, + names_transform=names_transform, + ) def _pivot_longer_create_spec( - column_names: Iterable, - names_to: Iterable, + column_names: list, + names_to: list, names_sep: str | None, names_pattern: str | None, - values_to: str, - names_transform: pl.Expr, + variable_name: str, ) -> pl.DataFrame: """ This is where the spec DataFrame is created, @@ -468,16 +496,16 @@ def _pivot_longer_create_spec( pl.col(".name") .str.split(by=names_sep) .list.to_struct(n_field_strategy="max_width") - .alias("extract") + .alias(variable_name) ) else: expression = ( pl.col(".name") .str.extract_groups(pattern=names_pattern) - .alias("extract") + .alias(variable_name) ) spec = spec.with_columns(expression) - len_fields = len(spec.get_column("extract").struct.fields) + len_fields = len(spec.get_column(variable_name).struct.fields) len_names_to = len(names_to) if len_names_to != len_fields: @@ -492,7 +520,7 @@ def _pivot_longer_create_spec( expression = pl.exclude(".name").is_null().any() expression = pl.any_horizontal(expression) null_check = ( - spec.unnest(columns="extract") + spec.unnest(columns=variable_name) .filter(expression) .get_column(".name") ) @@ -504,112 +532,132 @@ def _pivot_longer_create_spec( "in the provided regex. Kindly provide a regular expression " "(with the correct groups) that matches all labels in the columns." ) - if names_to.count(".value") < 2: - expression = pl.col("extract").struct.rename_fields(names=names_to) - spec = spec.with_columns(expression).unnest(columns="extract") - else: - spec = _squash_multiple_dot_value(spec=spec, names_to=names_to) + if ".value" not in names_to: - expression = pl.lit(value=values_to).alias(".value") - spec = spec.with_columns(expression) + spec = spec.get_column(variable_name) + spec = spec.struct.rename_fields(names=names_to) + return spec + if names_to.count(".value") == 1: + spec = spec.with_columns( + pl.col(variable_name).struct.rename_fields(names=names_to) + ) + not_dot_value = [name for name in names_to if name != ".value"] + spec = spec.unnest(variable_name) + if not_dot_value: + return spec.select( + ".name", + ".value", + pl.struct(not_dot_value).alias(variable_name), + ) + return spec.select(".name", ".value") + _spec = spec.get_column(variable_name) + _spec = _spec.struct.unnest() + fields = _spec.columns + + if len(set(names_to)) == 1: + expression = pl.concat_str(fields).alias(".value") + dot_value = _spec.select(expression) + dot_value = dot_value.to_series(0) + return spec.select(".name", dot_value) + dot_value = [ + field for field, label in zip(fields, names_to) if label == ".value" + ] + dot_value = pl.concat_str(dot_value).alias(".value") + not_dot_value = [ + pl.col(field).alias(label) + for field, label in zip(fields, names_to) + if label != ".value" + ] + not_dot_value = pl.struct(not_dot_value).alias(variable_name) + return _spec.select(spec.get_column(".name"), not_dot_value, dot_value) + - spec = spec.select( - pl.col([".name", ".value"]), pl.exclude([".name", ".value"]) +def _pivot_longer_no_dot_value( + df: pl.DataFrame | pl.LazyFrame, + spec: pl.DataFrame, + index: ColumnNameOrSelector, + column_names: ColumnNameOrSelector, + names_to: list | tuple, + values_to: str, + variable_name: str, + names_transform: pl.Expr, +) -> pl.DataFrame | pl.LazyFrame: + """ + flip polars Frame to long form, + if no .value in names_to. + """ + # the implode/explode approach is used here + # for efficiency + # do the operation on a smaller size + # and then blow it up after + # it is usually much faster + # than running on the actual data + outcome = ( + df.select(pl.all().implode()) + .melt( + id_vars=index, + value_vars=column_names, + variable_name=variable_name, + value_name=values_to, + ) + .with_columns(spec) ) + + outcome = outcome.unnest(variable_name) if names_transform is not None: - spec = spec.with_columns(names_transform) - return spec + outcome = outcome.with_columns(names_transform) + columns = [name for name in outcome.columns if name not in names_to] + outcome = outcome.explode(columns=columns) + return outcome def _pivot_longer_dot_value( - df: pl.DataFrame | pl.LazyFrame, spec: pl.DataFrame + df: pl.DataFrame | pl.LazyFrame, + spec: pl.DataFrame, + index: ColumnNameOrSelector, + variable_name: str, + dot_value_only: bool, + names_transform: pl.Expr, ) -> pl.DataFrame | pl.LazyFrame: """ - Reshape DataFrame to long form based on metadata in `spec`. + flip polars Frame to long form, + if names_sep and .value in names_to. """ - index = [column for column in df.columns if column not in spec[".name"]] - not_dot_value = [ - column for column in spec.columns if column not in {".name", ".value"} - ] - idx = "".join(spec.columns) - if not_dot_value: - # assign a number to each group (grouped by not_dot_value) - expression = pl.first(idx).over(not_dot_value).rank("dense").sub(1) - spec = spec.with_row_index(name=idx).with_columns(expression) - else: - # use a cumulative count to properly pair the columns - # grouped by .value - expression = pl.cum_count(".value").over(".value").alias(idx) - spec = spec.with_columns(expression) - mapping = defaultdict(list) - for position, column_name, replacement_name in zip( - spec.get_column(name=idx), - spec.get_column(name=".name"), - spec.get_column(name=".value"), + spec = spec.group_by(variable_name) + spec = spec.agg(pl.all()) + expressions = [] + for names, fields in zip( + spec.get_column(".name").to_list(), + spec.get_column(".value").to_list(), ): - expression = pl.col(column_name).alias(replacement_name) - mapping[position].append(expression) - - mapping = ( - ( - [ - *index, - *columns_to_select, - ], - pl.lit(position, dtype=pl.UInt32).alias(idx), - ) - for position, columns_to_select in mapping.items() + expression = pl.struct(names).struct.rename_fields(names=fields) + expressions.append(expression) + expressions = [*index, *expressions] + spec = spec.get_column(variable_name) + outcome = ( + df.select(expressions) + .select(pl.all().implode()) + .melt(id_vars=index, variable_name=variable_name, value_name=".value") + .with_columns(spec) ) - df = [ - df.select(columns_to_select).with_columns(position) - for columns_to_select, position in mapping - ] - # rechunking can be expensive; - # however subsequent operations are faster - # since data is contiguous in memory - df = pl.concat(df, how="diagonal_relaxed", rechunk=True) - expression = pl.cum_count(".value").over(".value").eq(1) - dot_value = spec.filter(expression).select(".value") - columns_to_select = [*index, *dot_value.to_series(0)] - if not_dot_value: - if isinstance(df, pl.LazyFrame): - ranges = df.select(idx).collect().get_column(idx) - else: - ranges = df.get_column(idx) - spec = spec.select(pl.struct(not_dot_value)) - _value = spec.columns[0] - expression = pl.cum_count(_value).over(_value).eq(1) - # using a gather approach, instead of a join - # offers more performance - not sure why - # maybe in the join there is another rechunking? - spec = spec.filter(expression).select(pl.col(_value).gather(ranges)) - df = df.with_columns(spec).unnest(_value) - columns_to_select.extend(not_dot_value) - return df.select(columns_to_select) - - -def _squash_multiple_dot_value( - spec: pl.DataFrame, names_to: Iterable -) -> pl.DataFrame: - """ - Combine multiple .values into a single .value column - """ - extract = spec.get_column("extract") - fields = extract.struct.fields - dot_value = [ - field for field, label in zip(fields, names_to) if label == ".value" - ] - dot_value = pl.concat_str(dot_value).alias(".value") - not_dot_value = [ - pl.col(field).alias(label) - for field, label in zip(fields, names_to) - if label != ".value" + + if dot_value_only: + columns = [ + label for label in outcome.columns if label != variable_name + ] + outcome = outcome.explode(columns).unnest(".value") + outcome = outcome.select(pl.exclude(variable_name)) + return outcome + outcome = outcome.unnest(variable_name) + if names_transform is not None: + outcome = outcome.with_columns(names_transform) + columns = [ + label for label in outcome.columns if label not in spec.struct.fields ] - select_expr = [".name", dot_value] - if not_dot_value: - select_expr.extend(not_dot_value) + outcome = outcome.explode(columns) + outcome = outcome.unnest(".value") - return spec.unnest("extract").select(select_expr) + return outcome def _data_checks_pivot_longer( @@ -620,7 +668,6 @@ def _data_checks_pivot_longer( values_to, names_sep, names_pattern, - names_transform, ) -> tuple: """ This function majorly does type checks on the passed arguments. @@ -630,57 +677,24 @@ def _data_checks_pivot_longer( Type annotations are not provided because this function is where type checking happens. """ - - def _check_type(arg_name: str, arg_value: Any): - """ - Raise if argument is not a valid type - """ - - def _check_type_single(entry): - if ( - not isinstance(entry, str) - and not cs.is_selector(entry) - and not isinstance(entry, pl.Expr) - ): - raise TypeError( - f"The argument passed to the {arg_name} parameter " - "should be a type that is supported in the polars' " - "select function." - ) - - if isinstance(arg_value, (list, tuple)): - for entry in arg_value: - _check_type_single(entry=entry) - else: - _check_type_single(entry=arg_value) - - if (index is None) and (column_names is None): - column_names = df.columns - index = [] - elif (index is not None) and (column_names is not None): - _check_type(arg_name="index", arg_value=index) - index = df.select(index).columns - _check_type(arg_name="column_names", arg_value=column_names) - column_names = df.select(column_names).columns - - elif (index is None) and (column_names is not None): - _check_type(arg_name="column_names", arg_value=column_names) - column_names = df.select(column_names).columns - index = df.select(pl.exclude(column_names)).columns - - elif (index is not None) and (column_names is None): - _check_type(arg_name="index", arg_value=index) - index = df.select(index).columns - column_names = df.select(pl.exclude(index)).columns - - check("names_to", names_to, [list, tuple, str]) - if isinstance(names_to, (list, tuple)): + if isinstance(names_to, str): + names_to = [names_to] + elif isinstance(names_to, (list, tuple)): uniques = set() for word in names_to: - check(f"'{word}' in names_to", word, [str]) + if not isinstance(word, str): + raise TypeError( + f"'{word}' in names_to should be a string type; " + f"instead got type {type(word).__name__}" + ) if (word in uniques) and (word != ".value"): raise ValueError(f"'{word}' is duplicated in names_to.") uniques.add(word) + else: + raise TypeError( + "names_to should be a string, list, or tuple; " + f"instead got type {type(names_to).__name__}" + ) if names_sep and names_pattern: raise ValueError( @@ -690,11 +704,24 @@ def _check_type_single(entry): if names_sep is not None: check("names_sep", names_sep, [str]) - if names_pattern is not None: + else: check("names_pattern", names_pattern, [str]) check("values_to", values_to, [str]) + if (index is None) and (column_names is None): + column_names = df.columns + index = [] + elif (index is None) and (column_names is not None): + column_names = df.select(column_names).columns + index = df.select(pl.exclude(column_names)).columns + elif (index is not None) and (column_names is None): + index = df.select(index).columns + column_names = df.select(pl.exclude(index)).columns + else: + index = df.select(index).columns + column_names = df.select(column_names).columns + return ( df, index, @@ -703,5 +730,4 @@ def _check_type_single(entry): values_to, names_sep, names_pattern, - names_transform, ) diff --git a/tests/polars/functions/test_pivot_longer_polars.py b/tests/polars/functions/test_pivot_longer_polars.py index de43db0d7..d2942c9fc 100644 --- a/tests/polars/functions/test_pivot_longer_polars.py +++ b/tests/polars/functions/test_pivot_longer_polars.py @@ -19,25 +19,9 @@ def df_checks(): ) -def test_type_index(df_checks): - """Raise TypeError if wrong type is provided for the index.""" - msg = "The argument passed to the index parameter " - msg += "should be a type that is supported in the.+" - with pytest.raises(TypeError, match=msg): - df_checks.pivot_longer(index=2007, names_sep="_") - - -def test_type_column_names(df_checks): - """Raise TypeError if wrong type is provided for column_names.""" - msg = "The argument passed to the column_names parameter " - msg += "should be a type that is supported in the.+" - with pytest.raises(TypeError, match=msg): - df_checks.pivot_longer(column_names=2007, names_sep="_") - - def test_type_names_to(df_checks): """Raise TypeError if wrong type is provided for names_to.""" - msg = "names_to should be one of .+" + msg = "names_to should be a string, list, or tuple.+" with pytest.raises(TypeError, match=msg): df_checks.pivot_longer(names_to=2007, names_sep="_") @@ -90,38 +74,6 @@ def test_values_to_wrong_type(df_checks): df_checks.pivot_longer(values_to={"salvo"}, names_sep="_") -def test_pivot_index_only(df_checks): - """Test output if only index is passed.""" - result = df_checks.pivot_longer( - index=["famid", "birth"], - names_to="dim", - values_to="num", - ) - - actual = df_checks.melt( - id_vars=["famid", "birth"], variable_name="dim", value_name="num" - ) - - assert_frame_equal(result, actual, check_column_order=False) - - -def test_pivot_column_only(df_checks): - """Test output if only column_names is passed.""" - result = df_checks.pivot_longer( - column_names=["ht1", "ht2"], - names_to="dim", - values_to="num", - ) - - actual = df_checks.melt( - id_vars=["famid", "birth"], - variable_name="dim", - value_name="num", - ) - - assert_frame_equal(result, actual, check_column_order=False) - - def test_names_to_names_pattern_len(df_checks): """ " Raise ValueError @@ -161,12 +113,16 @@ def test_names_pat_str(df_checks): Test output when names_pattern is a string, and .value is present. """ - result = df_checks.pivot_longer( - column_names=cs.starts_with("ht"), - names_to=(".value", "age"), - names_pattern="(.+)(.)", - names_transform=pl.col("age").cast(pl.Int64), - ).sort(by=pl.all()) + result = ( + df_checks.pivot_longer( + index=["famid", "birth"], + names_to=(".value", "age"), + names_pattern="(.+)(.)", + names_transform=pl.col("age").cast(pl.Int64), + ) + .select("famid", "birth", "age", "ht") + .sort(by=pl.all()) + ) actual = [ {"famid": 1, "birth": 1, "age": 1, "ht": 2.8}, @@ -190,20 +146,7 @@ def test_names_pat_str(df_checks): ] actual = pl.DataFrame(actual).sort(by=pl.all()) - assert_frame_equal( - result, actual, check_dtype=False, check_column_order=False - ) - - -def test_no_column_names(df_checks): - """ - Test output if all the columns - are assigned to the index parameter. - """ - assert_frame_equal( - df_checks.pivot_longer(index=pl.all()), - df_checks, - ) + assert_frame_equal(result, actual) @pytest.fixture @@ -310,23 +253,31 @@ def test_df(): def test_names_pattern_dot_value(test_df): """Test output for names_pattern and .value.""" - result = test_df.pivot_longer( - column_names=pl.all(), - names_to=["set", ".value"], - names_pattern="(.+)_(.+)", - ).sort(by=["loc", "lat", "long"]) - assert_frame_equal(result, actual, check_column_order=False) + result = ( + test_df.pivot_longer( + column_names=cs.all(), + names_to=["set", ".value"], + names_pattern="(.+)_(.+)", + ) + .sort(by=["loc", "lat", "long"]) + .select("set", "loc", "lat", "long") + ) + assert_frame_equal(result, actual) def test_names_sep_dot_value(test_df): """Test output for names_pattern and .value.""" - result = test_df.pivot_longer( - column_names=pl.all(), - names_to=["set", ".value"], - names_sep="_", - ).sort(by=["loc", "lat", "long"]) - assert_frame_equal(result, actual, check_column_order=False) + result = ( + test_df.pivot_longer( + column_names=cs.all(), + names_to=["set", ".value"], + names_sep="_", + ) + .sort(by=["loc", "lat", "long"]) + .select("set", "loc", "lat", "long") + ) + assert_frame_equal(result, actual) @pytest.fixture @@ -388,7 +339,7 @@ def test_not_dot_value_sep2(not_dot_value): "country", variable_name="event", value_name="score" ) - assert_frame_equal(result, actual, check_column_order=False) + assert_frame_equal(result, actual) def test_not_dot_value_pattern(not_dot_value): @@ -460,7 +411,7 @@ def test_multiple_dot_value(): actual = pl.DataFrame(actual).sort(by=pl.all()) - assert_frame_equal(result, actual, check_column_order=False) + assert_frame_equal(result, actual) @pytest.fixture @@ -482,7 +433,7 @@ def test_multiple_dot_value2(single_val): index="id", names_to=(".value", ".value"), names_pattern="(.)(.)" ) - assert_frame_equal(result, single_val, check_column_order=False) + assert_frame_equal(result, single_val) actual3 = [ @@ -506,7 +457,7 @@ def test_names_pattern_single_column(single_val): "id", names_to=".value", names_pattern="(.)." ) - assert_frame_equal(result, actual3, check_column_order=False) + assert_frame_equal(result.sort(by=pl.all()), actual3.sort(by=pl.all())) def test_names_pattern_single_column_not_dot_value(single_val): @@ -515,12 +466,11 @@ def test_names_pattern_single_column_not_dot_value(single_val): """ result = single_val.pivot_longer( index="id", column_names="x1", names_to="yA", names_pattern="(.+)" - ) + ).select("id", "yA", "value") assert_frame_equal( result, single_val.melt(id_vars="id", value_vars="x1", variable_name="yA"), - check_column_order=False, ) @@ -528,14 +478,15 @@ def test_names_pattern_single_column_not_dot_value1(single_val): """ Test output if names_to is not '.value'. """ - result = single_val.select("x1").pivot_longer( - names_to="yA", names_pattern="(.+)" + result = ( + single_val.select("x1") + .pivot_longer(names_to="yA", names_pattern="(.+)") + .select("yA", "value") ) assert_frame_equal( result, single_val.select("x1").melt(variable_name="yA"), - check_column_order=False, ) @@ -592,4 +543,4 @@ def test_names_pattern_nulls_in_data(df_null): actual = pl.DataFrame(actual).sort(by=pl.all()) - assert_frame_equal(result, actual, check_column_order=False) + assert_frame_equal(result, actual)