Skip to content

Commit b57ccf0

Browse files
[v3-0-test] Fix reading huge (XCom) response in TaskSDK task process (#53186) (#53194)
If you tried to send a large XCom value, it would fail in the task/child process side with this error: > RuntimeError: unable to read full response in child. (We read 36476, but expected 1310046) (The exact number of bytes it was able to read depended on many different factors, like the OS, the current state of the socket and other things. Sometimes it would read up to 256kb fine, other times only 35kb as here.) This is because the kernel-level read-side socket buffer is full, so that was as much as the Supervisor could send. The fix is to read in a loop until we get it all. (cherry picked from commit b9620bf) Co-authored-by: Ash Berlin-Taylor <[email protected]>
1 parent 62c83d1 commit b57ccf0

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,16 @@ def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[i
228228
length = int.from_bytes(len_bytes, byteorder="big")
229229

230230
buffer = bytearray(length)
231-
nread = self.socket.recv_into(buffer)
232-
if nread != length:
233-
raise RuntimeError(
234-
f"unable to read full response in child. (We read {nread}, but expected {length})"
235-
)
236-
if nread == 0:
237-
raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
238-
239-
resp = self.resp_decoder.decode(buffer)
231+
mv = memoryview(buffer)
232+
233+
pos = 0
234+
while pos < length:
235+
nread = self.socket.recv_into(mv[pos:])
236+
if nread == 0:
237+
raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
238+
pos += nread
239+
240+
resp = self.resp_decoder.decode(mv)
240241
if maxfds:
241242
return resp, fds or []
242243
return resp

task-sdk/tests/task_sdk/execution_time/test_comms.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
import threading
2021
import uuid
2122
from socket import socketpair
2223

@@ -81,3 +82,32 @@ def test_recv_StartupDetails(self):
8182
assert msg.dag_rel_path == "/dev/null"
8283
assert msg.bundle_info == BundleInfo(name="any-name", version="any-version")
8384
assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
85+
86+
def test_huge_payload(self):
    """Check that a response frame much larger than one socket read is received in full.

    A roughly 10 MiB ``XComResult`` body is written to one end of a socket pair,
    prefixed with its length as a 4-byte big-endian integer, and read back
    through ``CommsDecoder._get_response`` on the other end.
    """
    expected_len = 10 * 1024 * 1024 + 1
    payload = {
        "type": "XComResult",
        "key": "a",
        "value": ("a" * (expected_len - 1)) + "b",  # A 10mb xcom value
    }

    r, w = socketpair()
    # Close both ends when the test finishes so no file descriptors leak.
    with r, w:
        w.settimeout(1.0)
        body = msgspec.msgpack.encode(_ResponseFrame(0, payload, None))
        frame = len(body).to_bytes(4, byteorder="big") + body

        # Since `sendall` blocks, we need to do the send in another thread, so we can perform the read here
        sender = threading.Thread(target=w.sendall, args=(frame,))
        sender.start()

        decoder = CommsDecoder(socket=r, log=None)

        try:
            msg = decoder._get_response()
        finally:
            sender.join(2)

        # `join` with a timeout returns silently; make a hung sender visible.
        assert not sender.is_alive()

        assert msg is not None

        # It actually failed to read at all for large values, but let's just make sure we get it all
        assert len(msg.value) == expected_len
        assert msg.value[-1] == "b"

0 commit comments

Comments
 (0)