@@ -12,6 +12,7 @@ def __init__(
1212 file_encoding : str = "utf-8" ,
1313 no_log : Optional [Dict [str , Any ]] = None ,
1414 record_writes : bool = False ,
15+ slog_buffer : Optional [io .StringIO ] = None ,
1516 ) -> None :
1617 if no_log is None :
1718 self .no_log = {}
@@ -30,6 +31,13 @@ def __init__(
3031 else :
3132 self .session_log = None
3233
34+ # In order to ensure all the no_log entries get hidden properly,
35+ # we must first store everying in memory and then write out to file.
36+ # Otherwise, we might miss the data we are supposed to hide (since
37+ # the no_log data potentially spans multiple reads).
38+ if slog_buffer is None :
39+ self .slog_buffer = io .StringIO ()
40+
3341 # Ensures last write operations prior to disconnect are recorded.
3442 self .fin = False
3543
@@ -49,15 +57,30 @@ def open(self) -> None:
4957
5058 def close (self ) -> None :
5159 """Close the session_log file (if it is a file that we opened)."""
60+ self .flush ()
5261 if self .session_log and self ._session_log_close :
5362 self .session_log .close ()
5463 self .session_log = None
5564
56- def write (self , data : str ) -> None :
57- if self .session_log is not None and len (data ) > 0 :
58- # Hide the password and secret in the session_log
59- for hidden_data in self .no_log .values ():
60- data = data .replace (hidden_data , "********" )
65+ def no_log_filter (self , data : str ) -> str :
66+ """Filter content from the session_log."""
67+ for hidden_data in self .no_log .values ():
68+ data = data .replace (hidden_data , "********" )
69+ return data
70+
71+ def _read_buffer (self ) -> str :
72+ self .slog_buffer .seek (0 )
73+ data = self .slog_buffer .read ()
74+ # Once read, create a new buffer
75+ self .slog_buffer = io .StringIO ()
76+ return data
77+
78+ def flush (self ) -> None :
79+ """Force the slog_buffer to be written out to the actual file"""
80+
81+ if self .session_log is not None :
82+ data = self ._read_buffer ()
83+ data = self .no_log_filter (data )
6184
6285 if isinstance (self .session_log , io .BufferedIOBase ):
6386 self .session_log .write (write_bytes (data , encoding = self .file_encoding ))
@@ -67,4 +90,10 @@ def write(self, data: str) -> None:
6790 assert isinstance (self .session_log , io .BufferedIOBase ) or isinstance (
6891 self .session_log , io .TextIOBase
6992 )
93+
94+ # Flush the underlying file
7095 self .session_log .flush ()
96+
97+ def write (self , data : str ) -> None :
98+ if len (data ) > 0 :
99+ self .slog_buffer .write (data )
0 commit comments