Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Fix bug where Ray Data incorrectly emits progress bar warning #47680

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,6 @@ def _truncate_name(self, name: str) -> str:
):
return name

if log_once("ray_data_truncate_operator_name"):
logger.warning(
f"Truncating long operator name to {self.MAX_NAME_LENGTH} characters."
"To disable this behavior, set `ray.data.DataContext.get_current()."
"DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
)
op_names = name.split("->")
if len(op_names) == 1:
return op_names[0]
Expand All @@ -141,6 +135,13 @@ def _truncate_name(self, name: str) -> str:
+ len(op_names[-1])
) > self.MAX_NAME_LENGTH:
truncated_op_names.append("...")
if log_once("ray_data_truncate_operator_name"):
logger.warning(
f"Truncating long operator name to {self.MAX_NAME_LENGTH} "
"characters. To disable this behavior, set "
"`ray.data.DataContext.get_current()."
"DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
)
break
truncated_op_names.append(op_name)
truncated_op_names.append(op_names[-1])
Expand Down Expand Up @@ -199,6 +200,9 @@ def set_description(self, name: str) -> None:
self._desc = name
self._bar.set_description(self._desc)

def get_description(self) -> str:
return self._desc

def refresh(self):
if self._bar:
self._bar.refresh()
Expand Down
47 changes: 43 additions & 4 deletions python/ray/data/tests/test_progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import logging
from unittest.mock import patch

import pytest
from pytest import fixture
Expand Down Expand Up @@ -39,7 +41,7 @@ def wrapped_close():
bar.close = wrapped_close

# Test basic usage
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was broken

assert pb._bar is not None
patch_close(pb._bar)
for _ in range(total):
Expand All @@ -50,7 +52,7 @@ def wrapped_close():
assert total_at_close == total

# Test if update() exceeds the original total, the total will be updated.
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
Expand All @@ -62,7 +64,7 @@ def wrapped_close():
assert total_at_close == new_total

# Test that if the bar is not complete at close(), the total will be updated.
pb = ProgressBar("", total, "")
pb = ProgressBar("", total, "unit")
assert pb._bar is not None
patch_close(pb._bar)
new_total = total // 2
Expand All @@ -74,7 +76,7 @@ def wrapped_close():
assert total_at_close == new_total

# Test updating the total
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
Expand All @@ -84,3 +86,40 @@ def wrapped_close():
pb.update(total + 1, total)
assert pb._bar.total == total + 1
pb.close()


@pytest.mark.parametrize(
"name, expected_description, max_line_length, should_emit_warning",
[
("Op", "Op", 2, False),
("Op->Op", "Op->Op", 5, False),
("Op->Op->Op", "Op->...->Op", 9, True),
("Op->Op->Op", "Op->Op->Op", 10, False),
# Test case for https://github.com/ray-project/ray/issues/47679.
("spam", "spam", 1, False),
],
)
def test_progress_bar_truncates_chained_operators(
name,
expected_description,
max_line_length,
should_emit_warning,
caplog,
propagate_logs,
):
with patch.object(ProgressBar, "MAX_NAME_LENGTH", max_line_length):
pb = ProgressBar(name, None, "unit")

assert pb.get_description() == expected_description
if should_emit_warning:
assert any(
record.levelno == logging.WARNING
and "Truncating long operator name" in record.message
for record in caplog.records
), caplog.records


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))
Comment on lines +122 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest strikes again...

Loading