Skip to content

Commit

Permalink
Merge pull request #3449 from jsiirola/tee-stream-fixes
Browse files Browse the repository at this point in the history
Resolve buffering issues in `TeeStream` and `capture_output`
  • Loading branch information
mrmundt authored Feb 4, 2025
2 parents 095a6ed + 38ff353 commit 8a50553
Show file tree
Hide file tree
Showing 10 changed files with 563 additions and 225 deletions.
174 changes: 137 additions & 37 deletions pyomo/common/tee.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@
logger = logging.getLogger(__name__)


class _SignalFlush(object):
def __init__(self, ostream, handle):
super().__setattr__('_ostream', ostream)
super().__setattr__('_handle', handle)

def flush(self):
self._ostream.flush()
self._handle.flush = True

def __getattr__(self, attr):
return getattr(self._ostream, attr)

def __setattr__(self, attr, val):
return setattr(self._ostream, attr, val)


class _AutoFlush(_SignalFlush):
def write(self, data):
self._ostream.write(data)
self.flush()

def writelines(self, data):
self._ostream.writelines(data)
self.flush()


class redirect_fd(object):
"""Redirect a file descriptor to a new file or file descriptor.
Expand Down Expand Up @@ -152,10 +178,33 @@ def __exit__(self, t, v, traceback):


class capture_output(object):
"""
Drop-in substitute for PyUtilib's capture_output.
Takes in a StringIO, file-like object, or filename and temporarily
redirects output to a string buffer.
"""Context manager to capture output sent to sys.stdout and sys.stderr
This is a drop-in substitute for PyUtilib's capture_output to
temporarily redirect output to the provided stream or file.
Parameters
----------
output : io.TextIOBase, TeeStream, str, or None
Output stream where all captured stdout/stderr data is sent. If
a ``str`` is provided, it is used as a file name and opened
(potentially overwriting any existing file). If ``None``, a
:class:`io.StringIO` object is created and used.
capture_fd : bool
If True, we will also redirect the low-level file descriptors
associated with stdout (1) and stderr (2) to the ``output``.
This is useful for capturing output emitted directly to the
process stdout / stderr by external compiled modules.
Returns
-------
io.TextIOBase
This is the output stream object where all data is sent.
"""

def __init__(self, output=None, capture_fd=False):
Expand All @@ -169,19 +218,22 @@ def __init__(self, output=None, capture_fd=False):
self.fd_redirect = None

def __enter__(self):
self.old = (sys.stdout, sys.stderr)
if isinstance(self.output, str):
self.output_stream = open(self.output, 'w')
else:
self.output_stream = self.output
self.old = (sys.stdout, sys.stderr)
self.tee = TeeStream(self.output_stream)
if isinstance(self.output, TeeStream):
self.tee = self.output
else:
self.tee = TeeStream(self.output_stream)
self.tee.__enter__()
sys.stdout = self.tee.STDOUT
sys.stderr = self.tee.STDERR
if self.capture_fd:
self.fd_redirect = (
redirect_fd(1, sys.stdout.fileno()),
redirect_fd(2, sys.stderr.fileno()),
redirect_fd(1, self.tee.STDOUT.fileno(), synchronize=False),
redirect_fd(2, self.tee.STDERR.fileno(), synchronize=False),
)
self.fd_redirect[0].__enter__()
self.fd_redirect[1].__enter__()
Expand Down Expand Up @@ -220,6 +272,7 @@ class _StreamHandle(object):
def __init__(self, mode, buffering, encoding, newline):
self.buffering = buffering
self.newlines = newline
self.flush = False
self.read_pipe, self.write_pipe = os.pipe()
if not buffering and 'b' not in mode:
# While we support "unbuffered" behavior in text mode,
Expand All @@ -233,6 +286,13 @@ def __init__(self, mode, buffering, encoding, newline):
newline=newline,
closefd=False,
)
if not self.buffering and buffering:
# We want this stream to be unbuffered, but Python doesn't
# allow it for text streams. Mock up an unbuffered stream
# using AutoFlush
self.write_file = _AutoFlush(self.write_file, self)
else:
self.write_file = _SignalFlush(self.write_file, self)
self.decoder_buffer = b''
try:
self.encoding = encoding or self.write_file.encoding
Expand Down Expand Up @@ -268,9 +328,7 @@ def close(self):
def finalize(self, ostreams):
self.decodeIncomingBuffer()
if ostreams:
# Turn off buffering for the final write
self.buffering = 0
self.writeOutputBuffer(ostreams)
self.writeOutputBuffer(ostreams, True)
os.close(self.read_pipe)

if self.output_buffer:
Expand Down Expand Up @@ -307,10 +365,10 @@ def decodeIncomingBuffer(self):
self.output_buffer += chars
self.decoder_buffer = self.decoder_buffer[raw_len:]

def writeOutputBuffer(self, ostreams):
def writeOutputBuffer(self, ostreams, flush):
if not self.encoding:
ostring, self.output_buffer = self.output_buffer, b''
elif self.buffering == 1:
elif self.buffering > 0 and not flush:
EOL = self.output_buffer.rfind(self.newlines or '\n') + 1
ostring = self.output_buffer[:EOL]
self.output_buffer = self.output_buffer[EOL:]
Expand All @@ -320,13 +378,15 @@ def writeOutputBuffer(self, ostreams):
if not ostring:
return

for stream in ostreams:
for local_stream, user_stream in ostreams:
try:
written = stream.write(ostring)
written = local_stream.write(ostring)
except:
written = 0
if written and not self.buffering:
stream.flush()
if flush or (written and not self.buffering):
local_stream.flush()
if local_stream is not user_stream:
user_stream.flush()
# Note: some derived file-like objects fail to return the
# number of characters written (and implicitly return None).
# If we get None, we will just assume that everything was
Expand All @@ -335,30 +395,47 @@ def writeOutputBuffer(self, ostreams):
logger.error(
"Output stream (%s) closed before all output was "
"written to it. The following was left in "
"the output buffer:\n\t%r" % (stream, ostring[written:])
"the output buffer:\n\t%r" % (local_stream, ostring[written:])
)


class TeeStream(object):
def __init__(self, *ostreams, encoding=None):
self.ostreams = ostreams
def __init__(self, *ostreams, encoding=None, buffering=-1):
self.ostreams = []
self.encoding = encoding
self.buffering = buffering
self._stdout = None
self._stderr = None
self._handles = []
self._active_handles = []
self._threads = []
for user_stream in ostreams:
try:
fileno = user_stream.fileno()
except:
self.ostreams.append((user_stream, user_stream))
continue
local_stream = os.fdopen(
os.dup(fileno), mode=getattr(user_stream, 'mode', None), closefd=True
)
self.ostreams.append((local_stream, user_stream))

@property
def STDOUT(self):
if self._stdout is None:
self._stdout = self.open(buffering=1)
b = self.buffering
if b == -1:
b = 1
self._stdout = self.open(buffering=b)
return self._stdout

@property
def STDERR(self):
if self._stderr is None:
self._stderr = self.open(buffering=0)
b = self.buffering
if b == -1:
b = 0
self._stderr = self.open(buffering=b)
return self._stderr

def open(self, mode='w', buffering=-1, encoding=None, newline=None):
Expand Down Expand Up @@ -422,6 +499,9 @@ def close(self, in_exception=False):
self._active_handles.clear()
self._stdout = None
self._stderr = None
for local, orig in self.ostreams:
if orig is not local:
local.close()

def __enter__(self):
return self
Expand Down Expand Up @@ -454,15 +534,21 @@ def _start(self, handle):
def _streamReader(self, handle):
while True:
new_data = os.read(handle.read_pipe, io.DEFAULT_BUFFER_SIZE)
if not new_data:
if handle.flush:
flush = True
handle.flush = False
else:
flush = False
if new_data:
handle.decoder_buffer += new_data
elif not flush:
break
handle.decoder_buffer += new_data

# At this point, we have new data sitting in the
# handle.decoder_buffer
handle.decodeIncomingBuffer()
# Now, output whatever we have decoded to the output streams
handle.writeOutputBuffer(self.ostreams)
handle.writeOutputBuffer(self.ostreams, flush)
#
# print("STREAM READER: DONE")

Expand All @@ -473,6 +559,7 @@ def _mergedReader(self):
_fast_poll_ct = _poll_rampup
new_data = '' # something not None
while handles:
flush = False
if new_data is None:
# For performance reasons, we use very aggressive
# polling at the beginning (_poll_interval) and then
Expand All @@ -492,6 +579,9 @@ def _mergedReader(self):
if _mswindows:
for handle in list(handles):
try:
if handle.flush:
flush = True
handle.flush = False
pipe = get_osfhandle(handle.read_pipe)
numAvail = PeekNamedPipe(pipe, 0)[1]
if numAvail:
Expand All @@ -500,8 +590,8 @@ def _mergedReader(self):
break
except:
handles.remove(handle)
new_data = None
if new_data is None:
new_data = '' # not None so the poll interval doesn't increase
if new_data is None and not flush:
# PeekNamedPipe is non-blocking; to avoid swamping
# the core, sleep for a "short" amount of time
time.sleep(_poll)
Expand All @@ -515,22 +605,32 @@ def _mergedReader(self):
# deadlocks when handles are added while select() is
# waiting
ready_handles = select(list(handles), noop, noop, _poll)[0]
if not ready_handles:
new_data = None
continue
if ready_handles:
handle = ready_handles[0]
new_data = os.read(handle.read_pipe, io.DEFAULT_BUFFER_SIZE)
if new_data:
handle.decoder_buffer += new_data
else:
handles.remove(handle)
new_data = '' # not None so the poll interval doesn't increase
else:
for handle in handles:
if handle.flush:
new_data = ''
break
else:
new_data = None
continue

handle = ready_handles[0]
new_data = os.read(handle.read_pipe, io.DEFAULT_BUFFER_SIZE)
if not new_data:
handles.remove(handle)
continue
handle.decoder_buffer += new_data
if handle.flush:
flush = True
handle.flush = False

# At this point, we have new data sitting in the
# handle.decoder_buffer
handle.decodeIncomingBuffer()

# Now, output whatever we have decoded to the output streams
handle.writeOutputBuffer(self.ostreams)
handle.writeOutputBuffer(self.ostreams, flush)
#
# print("MERGED READER: DONE")
Loading

0 comments on commit 8a50553

Please sign in to comment.