
Commit

bugfix: chunks get inserted in completion_first manner when completion_first = true
cdxker committed Aug 24, 2024
1 parent 6beaf13 commit c0cf518
Showing 2 changed files with 66 additions and 99 deletions.
150 changes: 52 additions & 98 deletions server/src/handlers/message_handler.rs
@@ -185,12 +185,23 @@ pub async fn create_message(
         .map(|message| {
             let mut message = message;
             if message.role == "assistant" {
-                message.content = message
-                    .content
-                    .split("||")
-                    .last()
-                    .unwrap_or("I give up, I can't find chunks for this message")
-                    .to_string();
+                if message.content.starts_with("[{") {
+                    // This is (chunks, content)
+                    message.content = message
+                        .content
+                        .split("||")
+                        .last()
+                        .unwrap_or("I give up, I can't find a citation")
+                        .to_string();
+                } else {
+                    // This is (content, chunks)
+                    message.content = message
+                        .content
+                        .rsplit("||")
+                        .last()
+                        .unwrap_or("I give up, I can't find a citation")
+                        .to_string();
+                }
             }
             message
         })
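For illustration, the new branch can be exercised in isolation. Below is a minimal sketch of the logic above (the helper name strip_citations and the sample strings are hypothetical, not part of this commit); the same branch is repeated in regenerate_message_patch further down:

    // Minimal sketch of the citation-stripping branch; helper name and
    // sample strings are hypothetical.
    fn strip_citations(content: &str) -> &str {
        if content.starts_with("[{") {
            // Stored as (chunks, content): the completion follows "||".
            content.split("||").last().unwrap_or("I give up, I can't find a citation")
        } else {
            // Stored as (content, chunks): rsplit iterates segments from the
            // end, so .last() yields the leading completion segment.
            content.rsplit("||").last().unwrap_or("I give up, I can't find a citation")
        }
    }

    fn main() {
        assert_eq!(strip_citations(r#"[{"id":1}]||Answer"#), "Answer");
        assert_eq!(strip_citations(r#"Answer||[{"id":1}]"#), "Answer");
    }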
@@ -247,8 +258,23 @@ pub async fn get_all_topic_messages(
 ) -> Result<HttpResponse, actix_web::Error> {
     let topic_id: uuid::Uuid = messages_topic_id.into_inner();
 
-    let messages =
-        get_messages_for_topic_query(topic_id, dataset_org_plan_sub.dataset.id, &pool).await?;
+    let messages: Vec<models::Message> =
+        get_messages_for_topic_query(topic_id, dataset_org_plan_sub.dataset.id, &pool)
+            .await?
+            .into_iter()
+            .filter_map(|mut message| {
+                if message.content.starts_with("||[{") {
+                    match message.content.rsplit_once("}]") {
+                        Some((chunks, ai_message)) => {
+                            message.content = format!("{}}}]{}", chunks, ai_message);
+                        }
+                        _ => return None,
+                    }
+                }
+
+                Some(message)
+            })
+            .collect();
 
     Ok(HttpResponse::Ok().json(messages))
 }
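For reference, str::rsplit_once splits at the last occurrence of the delimiter and returns both halves without it, so the format! call above stitches well-formed content back together unchanged, while messages that carry the "||[{" prefix but no closing "}]" appear to be dropped. A minimal sketch under that reading (sample inputs are illustrative, not real stored messages):

    // Sketch of the filter_map behavior above.
    fn normalize(content: String) -> Option<String> {
        if content.starts_with("||[{") {
            // rsplit_once splits at the LAST "}]"; the ? drops the message
            // when no "}]" is present.
            let (chunks, ai_message) = content.rsplit_once("}]")?;
            return Some(format!("{}}}]{}", chunks, ai_message));
        }
        Some(content)
    }

    fn main() {
        // Well-formed: content passes through unchanged.
        assert_eq!(
            normalize(r#"||[{"id":1}]Answer"#.to_string()).as_deref(),
            Some(r#"||[{"id":1}]Answer"#)
        );
        // Malformed: no closing "}]", so the message is filtered out.
        assert_eq!(normalize(r#"||[{"id":1"#.to_string()), None);
    }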
@@ -471,12 +497,23 @@ pub async fn regenerate_message_patch(
         .map(|message| {
             let mut message = message;
             if message.role == "assistant" {
-                message.content = message
-                    .content
-                    .split("||")
-                    .last()
-                    .unwrap_or("I give up, I can't find a citation")
-                    .to_string();
+                if message.content.starts_with("[{") {
+                    // This is (chunks, content)
+                    message.content = message
+                        .content
+                        .split("||")
+                        .last()
+                        .unwrap_or("I give up, I can't find a citation")
+                        .to_string();
+                } else {
+                    // This is (content, chunks)
+                    message.content = message
+                        .content
+                        .rsplit("||")
+                        .last()
+                        .unwrap_or("I give up, I can't find a citation")
+                        .to_string();
+                }
             }
             message
         })
@@ -549,90 +586,7 @@ pub async fn regenerate_message(
     pool: web::Data<Pool>,
     event_queue: web::Data<EventQueue>,
 ) -> Result<HttpResponse, actix_web::Error> {
-    let topic_id = data.topic_id;
-    let dataset_config =
-        DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.server_configuration.clone());
-
-    check_completion_param_validity(data.llm_options.clone())?;
-
-    let get_messages_pool = pool.clone();
-    let create_message_pool = pool.clone();
-    let dataset_id = dataset_org_plan_sub.dataset.id;
-
-    let mut previous_messages =
-        get_topic_messages(topic_id, dataset_id, &get_messages_pool).await?;
-
-    if previous_messages.len() < 2 {
-        return Err(
-            ServiceError::BadRequest("Not enough messages to regenerate".to_string()).into(),
-        );
-    }
-
-    if previous_messages.len() == 2 {
-        return stream_response(
-            previous_messages,
-            topic_id,
-            dataset_org_plan_sub.dataset,
-            create_message_pool,
-            event_queue,
-            dataset_config,
-            data.into_inner().into(),
-        )
-        .await;
-    }
-
-    // remove citations from the previous messages
-    previous_messages = previous_messages
-        .into_iter()
-        .map(|message| {
-            let mut message = message;
-            if message.role == "assistant" {
-                message.content = message
-                    .content
-                    .split("||")
-                    .last()
-                    .unwrap_or("I give up, I can't find a citation")
-                    .to_string();
-            }
-            message
-        })
-        .collect::<Vec<models::Message>>();
-
-    let mut message_to_regenerate = None;
-    for message in previous_messages.iter().rev() {
-        if message.role == "assistant" {
-            message_to_regenerate = Some(message.clone());
-            break;
-        }
-    }
-
-    let message_id = match message_to_regenerate {
-        Some(message) => message.id,
-        None => {
-            return Err(ServiceError::BadRequest("No message to regenerate".to_string()).into());
-        }
-    };
-
-    let mut previous_messages_to_regenerate = Vec::new();
-    for message in previous_messages.iter() {
-        if message.id == message_id {
-            break;
-        }
-        previous_messages_to_regenerate.push(message.clone());
-    }
-
-    delete_message_query(message_id, topic_id, dataset_id, &pool).await?;
-
-    stream_response(
-        previous_messages_to_regenerate,
-        topic_id,
-        dataset_org_plan_sub.dataset,
-        create_message_pool,
-        event_queue,
-        dataset_config,
-        data.into_inner().into(),
-    )
-    .await
+    regenerate_message_patch(data, user, dataset_org_plan_sub, pool, event_queue).await
 }
 
 #[derive(Deserialize, Serialize, Debug, ToSchema)]
15 changes: 14 additions & 1 deletion server/src/operators/message_operator.rs
@@ -611,12 +611,25 @@ pub async fn stream_response(
     let (s, r) = unbounded::<String>();
     let stream = client.chat().create_stream(parameters).await.unwrap();
 
+    let completion_first = create_message_req_payload
+        .llm_options
+        .as_ref()
+        .map(|x| x.completion_first)
+        .unwrap_or(Some(false))
+        .unwrap_or(false);
+
     Arbiter::new().spawn(async move {
         let chunk_v: Vec<String> = r.iter().collect();
         let completion = chunk_v.join("");
 
+        let message_to_be_stored = if completion_first {
+            format!("{}{}", completion, chunk_metadatas_stringified)
+        } else {
+            format!("{}{}", chunk_metadatas_stringified, completion)
+        };
+
         let new_message = models::Message::from_details(
-            format!("{}{}", chunk_metadatas_stringified, completion),
+            message_to_be_stored,
             topic_id,
             next_message_order().try_into().unwrap(),
             "assistant".to_string(),
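The double unwrap_or resolves the Option<Option<bool>> produced by an optional llm_options payload carrying an optional completion_first flag, defaulting to false at both levels. A minimal sketch of the resolution and the resulting storage order (LlmOptions here is a simplified stand-in for the real payload type):

    // Simplified stand-in for the payload's llm_options; only the one flag.
    struct LlmOptions { completion_first: Option<bool> }

    fn stored_content(opts: Option<&LlmOptions>, chunks: &str, completion: &str) -> String {
        let completion_first = opts
            .map(|o| o.completion_first) // Option<Option<bool>>
            .unwrap_or(Some(false))      // no llm_options at all => false
            .unwrap_or(false);           // llm_options without the flag => false
        if completion_first {
            format!("{}{}", completion, chunks) // (content, chunks)
        } else {
            format!("{}{}", chunks, completion) // (chunks, content)
        }
    }

With completion_first = true, the completion is stored ahead of the serialized chunks, which is the (content, chunks) ordering the message_handler.rs branches above strip with rsplit("||").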
