Commit b20edc1

fix: dia-1108: add fix to ensure all messages left in the output topic are sent to LSE and no predictions are lost (#111)

Authored by forum-hs and Forum Gala
Co-authored-by: Forum Gala <[email protected]>
1 parent 0f9aa91 · commit b20edc1

2 files changed: +28 -28 lines

adala/runtimes/_openai.py (5 additions, 5 deletions)

@@ -367,10 +367,10 @@ async def batch_to_batch(
         # check for errors - if any, append to outputs and continue
         if response.get("error"):
             # FIXME if we collect failed and succeeded outputs in the same list -> df, we end up with an awkward schema like this:
-            # output error message details
-            # ---------------------------
-            # output1 nan nan nan
-            # nan true message2 details2
+            # output error message details
+            # ---------------------------
+            # output1 nan nan nan
+            # nan true message2 details2
             # we are not going to send the error response to lse
             # outputs.append(response)
             if self.verbose:
@@ -392,7 +392,7 @@ async def batch_to_batch(
         # TODO: note that this doesn't work for multiple output fields e.g. `Output {output1} and Output {output2}`
         output_df = InternalDataFrame(outputs)
         # return output dataframe indexed as input batch.index, assuming outputs are in the same order as inputs
-        return output_df.set_index('index')
+        return output_df.set_index("index")
 
     async def record_to_record(
         self,
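
The FIXME above describes what happens when success rows and error rows share one DataFrame: pandas takes the union of the row keys and pads the gaps with NaN. A minimal standalone sketch of that schema (hypothetical data, and assuming InternalDataFrame behaves like a plain pandas DataFrame):

    import pandas as pd

    # one successful output row and one error row, as in the FIXME table
    outputs = [
        {"output": "output1"},
        {"error": True, "message": "message2", "details": "details2"},
    ]
    print(pd.DataFrame(outputs))
    #     output error   message   details
    # 0  output1   NaN       NaN       NaN
    # 1      NaN  True  message2  details2

Skipping `outputs.append(response)` for error responses keeps the schema clean, at the cost of not reporting those errors to LSE.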

server/tasks/process_file.py (23 additions, 23 deletions)

@@ -43,7 +43,7 @@ def process_file(agent: Agent):
     name="streaming_parent_task", track_started=True, bind=True, serializer="pickle"
 )
 def streaming_parent_task(
-    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 2
+    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 10
 ):
     """
     This task is used to launch the two tasks that are doing the real work, so that
@@ -155,30 +155,30 @@ async def async_process_streaming_output(
 
     input_job_running = True
 
+    data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
+
     while input_job_running:
-        try:
-            data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
-            for tp, messages in data.items():
-                if messages:
-                    logger.debug(f"Handling {messages=} in topic {tp.topic}")
-                    data = [msg.value for msg in messages]
-                    result_handler(data)
-                    logger.debug(
-                        f"Handled {len(messages)} messages in topic {tp.topic}"
-                    )
-                else:
-                    logger.debug(f"No messages in topic {tp.topic}")
-
-            if not data:
-                logger.info(f"No messages in any topic")
-        finally:
-            job = process_file_streaming.AsyncResult(input_job_id)
-            # TODO no way to recover here if connection to main app is lost, job will be stuck at "PENDING" so this will loop forever
-            if job.status in ["SUCCESS", "FAILURE", "REVOKED"]:
-                input_job_running = False
-                logger.info(f"Input job done, stopping output job")
-            else:
-                logger.info(f"Input job still running, keeping output job running")
+        for tp, messages in data.items():
+            if messages:
+                logger.debug(f"Handling {messages=} in topic {tp.topic}")
+                data = [msg.value for msg in messages]
+                result_handler(data)
+                logger.debug(f"Handled {len(messages)} messages in topic {tp.topic}")
+            else:
+                logger.debug(f"No messages in topic {tp.topic}")
+
+        if not data:
+            logger.info(f"No messages in any topic")
+
+        job = process_file_streaming.AsyncResult(input_job_id)
+        # we are getting packets from the output topic here to check if it's empty, and continue processing if it's not
+        data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
+        # TODO no way to recover here if connection to main app is lost, job will be stuck at "PENDING" so this will loop forever
+        if job.status in ["SUCCESS", "FAILURE", "REVOKED"] and len(data.items()) == 0:
+            input_job_running = False
+            logger.info(f"Input job done, stopping output job")
+        else:
+            logger.info(f"Input job still running, keeping output job running")
 
     await consumer.stop()
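
This restructured loop is the heart of the fix: the old version could see the input job reach a terminal state while records were still sitting unread in the output topic, and the `finally` branch would then end the loop, losing those predictions. The new version re-polls after each pass and exits only when the job is finished and the latest poll came back empty. The same pattern as a standalone sketch (aiokafka consumer assumed; `job_done` and `handle` are hypothetical stand-ins for the Celery status check and `result_handler`):

    from aiokafka import AIOKafkaConsumer

    async def drain_until_done(
        consumer: AIOKafkaConsumer, job_done, handle, batch_size: int = 10
    ) -> None:
        # Prime the first batch before entering the loop.
        data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
        running = True
        while running:
            for tp, messages in data.items():
                # Hand each partition's batch of decoded values to the handler.
                handle([msg.value for msg in messages])
            # Re-poll BEFORE the exit check: stopping is only safe once the
            # producer job is done AND the topic just came back empty.
            data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
            if job_done() and len(data) == 0:
                running = False
        await consumer.stop()

The sketch keeps the poll result and the unpacked message values in separate variables; the committed code reuses `data` for both, so its `if not data` emptiness check inspects the last unpacked batch rather than the poll result.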
