Skip to content

Commit

Permalink
feat: adds metadata and batch_index to batch_run (#6318)
Browse files Browse the repository at this point in the history
* Update batch_run.py

* updates to test component and fixes formatting

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: anovazzi1 <[email protected]>
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
  • Loading branch information
4 people authored Feb 14, 2025
1 parent ec5259a commit a1967bc
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 52 deletions.
173 changes: 133 additions & 40 deletions src/backend/base/langflow/components/helpers/batch_run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from loguru import logger

from langflow.custom import Component
from langflow.io import DataFrameInput, HandleInput, MultilineInput, Output, StrInput
from langflow.io import (
BoolInput,
DataFrameInput,
HandleInput,
MessageTextInput,
MultilineInput,
Output,
)
from langflow.schema import DataFrame

if TYPE_CHECKING:
Expand All @@ -14,8 +23,8 @@ class BatchRunComponent(Component):
display_name = "Batch Run"
description = (
"Runs a language model over each row of a DataFrame's text column and returns a new "
"DataFrame with two columns: 'text_input' (the original text) and 'model_response' "
"containing the model's response."
"DataFrame with three columns: '**text_input**' (the original text), "
"'**model_response**' (the model's response),and '**batch_index**' (the processing order)."
)
icon = "List"
beta = True
Expand All @@ -26,6 +35,7 @@ class BatchRunComponent(Component):
display_name="Language Model",
info="Connect the 'Language Model' output from your LLM component here.",
input_types=["LanguageModel"],
required=True,
),
MultilineInput(
name="system_message",
Expand All @@ -37,12 +47,23 @@ class BatchRunComponent(Component):
name="df",
display_name="DataFrame",
info="The DataFrame whose column (specified by 'column_name') we'll treat as text messages.",
required=True,
),
StrInput(
MessageTextInput(
name="column_name",
display_name="Column Name",
info="The name of the DataFrame column to treat as text messages. Default='text'.",
value="text",
required=True,
advanced=True,
),
BoolInput(
name="enable_metadata",
display_name="Enable Metadata",
info="If True, add metadata to the output DataFrame.",
value=True,
required=False,
advanced=True,
),
]

Expand All @@ -51,51 +72,123 @@ class BatchRunComponent(Component):
display_name="Batch Results",
name="batch_results",
method="run_batch",
info="A DataFrame with two columns: 'text_input' and 'model_response'.",
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
),
]

async def run_batch(self) -> DataFrame:
"""For each row in df[column_name], combine that text with system_message, then invoke the model asynchronously.
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
"""Create a base row with optional metadata."""
return {
"text_input": text_input,
"model_response": model_response,
"batch_index": batch_index,
}

def _add_metadata(
self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
) -> None:
"""Add metadata to a row if enabled."""
if not self.enable_metadata:
return

if success:
row["metadata"] = {
"has_system_message": bool(system_msg),
"input_length": len(row["text_input"]),
"response_length": len(row["model_response"]),
"processing_status": "success",
}
else:
row["metadata"] = {
"error": error,
"processing_status": "failed",
}

Returns a new DataFrame of the same length, with columns 'text_input' and 'model_response'.
async def run_batch(self) -> DataFrame:
    """Process each row of ``df[column_name]`` with the language model asynchronously.

    Each row's text is combined with the optional system message into a
    conversation, the whole batch is sent through ``model.abatch``, and the
    responses are collected into a new DataFrame.

    Returns:
        DataFrame: one row per input row, containing:
            - text_input: the original input text
            - model_response: the model's response
            - batch_index: the processing order
            - metadata: additional processing information (only when
              ``enable_metadata`` is set)

    Raises:
        TypeError: if the ``df`` input is not a DataFrame.
        ValueError: if the specified column is not found in the DataFrame.
    """
    model: Runnable = self.model
    system_msg = self.system_message or ""
    df: DataFrame = self.df
    col_name = self.column_name or "text"

    # Validate inputs before doing any model work.
    if not isinstance(df, DataFrame):
        msg = f"Expected DataFrame input, got {type(df)}"
        raise TypeError(msg)

    if col_name not in df.columns:
        msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}"
        raise ValueError(msg)

    try:
        # Convert the specified column to a list of strings.
        user_texts = df[col_name].astype(str).tolist()
        total_rows = len(user_texts)

        logger.info(f"Processing {total_rows} rows with batch run")

        # Prepare one conversation per row, prepending the system message if set.
        conversations = [
            [{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
            if system_msg
            else [{"role": "user", "content": text}]
            for text in user_texts
        ]

        # Configure the model with project info and callbacks for tracing.
        model = model.with_config(
            {
                "run_name": self.display_name,
                "project_name": self.get_project_name(),
                "callbacks": self.get_langchain_callbacks(),
            }
        )

        # abatch preserves input order, so enumerate is sufficient to track
        # batch_index — no index/sort round-trip needed.
        responses = await model.abatch(conversations)

        rows: list[dict[str, Any]] = []
        for idx, response in enumerate(responses):
            resp_text = response.content if hasattr(response, "content") else str(response)
            row = self._create_base_row(
                text_input=user_texts[idx],
                model_response=resp_text,
                batch_index=idx,
            )
            self._add_metadata(row, success=True, system_msg=system_msg)
            rows.append(row)

            # Log progress roughly every 10% of the batch.
            if (idx + 1) % max(1, total_rows // 10) == 0:
                logger.info(f"Processed {idx + 1}/{total_rows} rows")

        logger.info("Batch processing completed successfully")
        return DataFrame(rows)

    except (KeyError, AttributeError) as e:
        # Data-structure/attribute errors are surfaced as a single error row
        # instead of raising, so downstream components still get a DataFrame.
        logger.error(f"Data processing error: {e!s}")
        error_row = self._create_base_row()
        self._add_metadata(error_row, success=False, error=str(e))
        return DataFrame([error_row])
Loading

0 comments on commit a1967bc

Please sign in to comment.