Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds metadata and batch_index to batch_run #6318

Merged
merged 10 commits into from
Feb 14, 2025
173 changes: 133 additions & 40 deletions src/backend/base/langflow/components/helpers/batch_run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from loguru import logger

from langflow.custom import Component
from langflow.io import DataFrameInput, HandleInput, MultilineInput, Output, StrInput
from langflow.io import (
BoolInput,
DataFrameInput,
HandleInput,
MessageTextInput,
MultilineInput,
Output,
)
from langflow.schema import DataFrame

if TYPE_CHECKING:
Expand All @@ -14,8 +23,8 @@ class BatchRunComponent(Component):
display_name = "Batch Run"
description = (
"Runs a language model over each row of a DataFrame's text column and returns a new "
"DataFrame with two columns: 'text_input' (the original text) and 'model_response' "
"containing the model's response."
"DataFrame with three columns: '**text_input**' (the original text), "
"'**model_response**' (the model's response), and '**batch_index**' (the processing order)."
)
icon = "List"
beta = True
Expand All @@ -26,6 +35,7 @@ class BatchRunComponent(Component):
display_name="Language Model",
info="Connect the 'Language Model' output from your LLM component here.",
input_types=["LanguageModel"],
required=True,
),
MultilineInput(
name="system_message",
Expand All @@ -37,12 +47,23 @@ class BatchRunComponent(Component):
name="df",
display_name="DataFrame",
info="The DataFrame whose column (specified by 'column_name') we'll treat as text messages.",
required=True,
),
StrInput(
MessageTextInput(
name="column_name",
display_name="Column Name",
info="The name of the DataFrame column to treat as text messages. Default='text'.",
value="text",
required=True,
advanced=True,
),
BoolInput(
name="enable_metadata",
display_name="Enable Metadata",
info="If True, add metadata to the output DataFrame.",
value=True,
required=False,
advanced=True,
),
]

Expand All @@ -51,51 +72,123 @@ class BatchRunComponent(Component):
display_name="Batch Results",
name="batch_results",
method="run_batch",
info="A DataFrame with two columns: 'text_input' and 'model_response'.",
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
),
]

async def run_batch(self) -> DataFrame:
"""For each row in df[column_name], combine that text with system_message, then invoke the model asynchronously.
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
"""Create a base row with optional metadata."""
return {
"text_input": text_input,
"model_response": model_response,
"batch_index": batch_index,
}

def _add_metadata(
self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
) -> None:
"""Add metadata to a row if enabled."""
if not self.enable_metadata:
return

if success:
row["metadata"] = {
"has_system_message": bool(system_msg),
"input_length": len(row["text_input"]),
"response_length": len(row["model_response"]),
"processing_status": "success",
}
else:
row["metadata"] = {
"error": error,
"processing_status": "failed",
}

Returns a new DataFrame of the same length, with columns 'text_input' and 'model_response'.
async def run_batch(self) -> DataFrame:
"""Process each row in df[column_name] with the language model asynchronously.

Returns:
DataFrame: A new DataFrame containing:
- text_input: The original input text
- model_response: The model's response
- batch_index: The processing order
- metadata: Additional processing information

Raises:
ValueError: If the specified column is not found in the DataFrame
TypeError: If the model is not compatible or input types are wrong
"""
model: Runnable = self.model
system_msg = self.system_message or ""
df: DataFrame = self.df
col_name = self.column_name or "text"

# Validate inputs first
if not isinstance(df, DataFrame):
msg = f"Expected DataFrame input, got {type(df)}"
raise TypeError(msg)

if col_name not in df.columns:
msg = f"Column '{col_name}' not found in the DataFrame."
msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}"
raise ValueError(msg)

# Convert the specified column to a list of strings
user_texts = df[col_name].astype(str).tolist()

# Prepare the batch of conversations
conversations = [
[{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
if system_msg
else [{"role": "user", "content": text}]
for text in user_texts
]
model = model.with_config(
{
"run_name": self.display_name,
"project_name": self.get_project_name(),
"callbacks": self.get_langchain_callbacks(),
}
)

responses = await model.abatch(conversations)

# Build the final data, each row has 'text_input' + 'model_response'
rows = []
for original_text, response in zip(user_texts, responses, strict=False):
resp_text = response.content if hasattr(response, "content") else str(response)

row = {"text_input": original_text, "model_response": resp_text}
rows.append(row)

# Convert to a new DataFrame
return DataFrame(rows) # Langflow DataFrame from a list of dicts
try:
# Convert the specified column to a list of strings
user_texts = df[col_name].astype(str).tolist()
total_rows = len(user_texts)

logger.info(f"Processing {total_rows} rows with batch run")

# Prepare the batch of conversations
conversations = [
[{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
if system_msg
else [{"role": "user", "content": text}]
for text in user_texts
]

# Configure the model with project info and callbacks
model = model.with_config(
{
"run_name": self.display_name,
"project_name": self.get_project_name(),
"callbacks": self.get_langchain_callbacks(),
}
)

# Process batches and track progress
responses_with_idx = [
(idx, response)
for idx, response in zip(
range(len(conversations)), await model.abatch(list(conversations)), strict=True
)
]

# Sort by index to maintain order
responses_with_idx.sort(key=lambda x: x[0])

# Build the final data with enhanced metadata
rows: list[dict[str, Any]] = []
for idx, response in responses_with_idx:
resp_text = response.content if hasattr(response, "content") else str(response)
row = self._create_base_row(
text_input=user_texts[idx],
model_response=resp_text,
batch_index=idx,
)
self._add_metadata(row, success=True, system_msg=system_msg)
rows.append(row)

# Log progress
if (idx + 1) % max(1, total_rows // 10) == 0:
logger.info(f"Processed {idx + 1}/{total_rows} rows")

logger.info("Batch processing completed successfully")
return DataFrame(rows)

except (KeyError, AttributeError) as e:
# Handle data structure and attribute access errors
logger.error(f"Data processing error: {e!s}")
error_row = self._create_base_row()
self._add_metadata(error_row, success=False, error=str(e))
return DataFrame([error_row])
Loading
Loading