Skip to content

Commit 8666100

Browse files
authored
Merge pull request #12 from Wolfenheimm/feature/error-support
Custom Error Support for LoomError
2 parents cf1fff4 + b6bc6f8 commit 8666100

File tree

6 files changed

+166
-165
lines changed

6 files changed

+166
-165
lines changed

src/lib.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#![feature(once_cell_try)]
4949

5050
use std::{
51+
error::Error,
5152
fmt::{Debug, Display},
5253
marker::PhantomData,
5354
str::FromStr,
@@ -73,11 +74,11 @@ mod mock;
7374
mod tests;
7475

7576
pub use storage::TapestryChestHandler;
76-
use types::{LoomError, SummaryModelTokens, WeaveError};
77+
use types::{LoomError, SummaryModelTokens};
7778

7879
use crate::types::{PromptModelTokens, WrapperRole};
7980

80-
pub type Result<T> = std::result::Result<T, LoomError>;
81+
pub type Result<T, U> = std::result::Result<T, LoomError<U>>;
8182

8283
/// Represents a unique identifier for any arbitrary entity.
8384
///
@@ -162,6 +163,7 @@ pub trait Llm<T: Config>:
162163
type Response: Clone + Into<Option<String>> + Send;
163164
/// Type representing the parameters for a prompt.
164165
type Parameters: Debug + Clone + Send + Sync;
166+
type PromptError: Error;
165167

166168
/// The maximum number of tokens that can be processed at once by an LLM model.
167169
fn max_context_length(&self) -> Self::Tokens;
@@ -178,7 +180,7 @@ pub trait Llm<T: Config>:
178180
/// Calculates the number of tokens in a string.
179181
///
180182
/// This may vary depending on the type of tokens used by the LLM. In the case of ChatGPT, can be calculated using the [tiktoken-rs](https://github.com/zurawiki/tiktoken-rs#counting-token-length) crate.
181-
fn count_tokens(content: &str) -> Result<Self::Tokens>;
183+
fn count_tokens(content: &str) -> Result<Self::Tokens, T>;
182184
/// Prompt LLM with the supplied messages and parameters.
183185
async fn prompt(
184186
&self,
@@ -187,7 +189,7 @@ pub trait Llm<T: Config>:
187189
msgs: Vec<Self::Request>,
188190
params: &Self::Parameters,
189191
max_tokens: Self::Tokens,
190-
) -> Result<Self::Response>;
192+
) -> Result<Self::Response, T>;
191193
/// Compute cost of a message based on model.
192194
fn compute_cost(&self, prompt_tokens: Self::Tokens, response_tokens: Self::Tokens) -> f64;
193195
/// Calculate the upperbound of tokens allowed for the current [`Config::PromptModel`] before a
@@ -301,12 +303,10 @@ impl<T: Config> TapestryFragment<T> {
301303
/// Add a [`ContextMessage`] to the `context_messages` list.
302304
///
303305
/// Also increments the `context_tokens` by the number of tokens in the message.
304-
fn push_message(&mut self, msg: ContextMessage<T>) -> Result<()> {
306+
fn push_message(&mut self, msg: ContextMessage<T>) -> Result<(), T> {
305307
let tokens = T::PromptModel::count_tokens(&msg.content)?;
306308
let new_token_count = self.context_tokens.checked_add(&tokens).ok_or_else(|| {
307-
LoomError::from(WeaveError::BadConfig(
308-
"Number of tokens exceeds max tokens for model".to_string(),
309-
))
309+
LoomError::BadConfig("Number of tokens exceeds max tokens for model".to_string())
310310
})?;
311311

312312
trace!("Pushing message: {:?}, new token count: {}", msg, new_token_count);
@@ -319,7 +319,7 @@ impl<T: Config> TapestryFragment<T> {
319319
/// Add a [`ContextMessage`] to the `context_messages` list.
320320
///
321321
/// Also increments the `context_tokens` by the number of tokens in the message.
322-
fn extend_messages(&mut self, msgs: Vec<ContextMessage<T>>) -> Result<()> {
322+
fn extend_messages(&mut self, msgs: Vec<ContextMessage<T>>) -> Result<(), T> {
323323
let total_new_tokens = msgs
324324
.iter()
325325
.map(|m| T::PromptModel::count_tokens(&m.content).unwrap())
@@ -332,9 +332,7 @@ impl<T: Config> TapestryFragment<T> {
332332
trace!("Extending messages with token sum: {}", sum);
333333

334334
let new_token_count = self.context_tokens.checked_add(&sum).ok_or_else(|| {
335-
LoomError::from(WeaveError::BadConfig(
336-
"Number of tokens exceeds max tokens for model".to_string(),
337-
))
335+
LoomError::BadConfig("Number of tokens exceeds max tokens for model".to_string())
338336
})?;
339337

340338
// Update the token count and messages only if all checks pass

src/loom.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tracing::{debug, error, instrument, trace};
66
use crate::{
77
types::{
88
LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, VecPromptMsgsDeque,
9-
WeaveError, WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
9+
WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
1010
},
1111
Config, ContextMessage, Llm, LlmConfig, TapestryChestHandler, TapestryFragment, TapestryId,
1212
};
@@ -52,7 +52,7 @@ impl<T: Config> Loom<T> {
5252
tapestry_id: TID,
5353
instructions: String,
5454
mut msgs: Vec<ContextMessage<T>>,
55-
) -> Result<(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool), LoomError> {
55+
) -> Result<(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool), LoomError<T>> {
5656
let instructions_ctx_msg =
5757
Self::build_context_message(SYSTEM_ROLE.into(), instructions, None);
5858
let instructions_req_msg: PromptModelRequest<T> = instructions_ctx_msg.clone().into();
@@ -150,7 +150,7 @@ impl<T: Config> Loom<T> {
150150
trace!("Max completion tokens available: {:?}", max_completion_tokens);
151151

152152
if max_completion_tokens.is_zero() {
153-
return Err(LoomError::from(WeaveError::MaxCompletionTokensIsZero).into());
153+
return Err(LoomError::MaxCompletionTokensIsZero.into());
154154
}
155155

156156
trace!("Prompting LLM with request messages");
@@ -208,7 +208,7 @@ impl<T: Config> Loom<T> {
208208
summary_model_config: &LlmConfig<T, T::SummaryModel>,
209209
tapestry_fragment: &TapestryFragment<T>,
210210
summary_max_tokens: SummaryModelTokens<T>,
211-
) -> Result<String, LoomError> {
211+
) -> Result<String, LoomError<T>> {
212212
trace!(
213213
"Generating summary with max tokens: {:?}, for tapestry fragment: {:?}",
214214
summary_max_tokens,

src/mock.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,49 +23,52 @@ impl TapestryChestHandler<MockConfig> for MockChest {
2323
_tapestry_id: &TID,
2424
_tapestry_fragment: TapestryFragment<MockConfig>,
2525
_increment: bool,
26-
) -> crate::Result<u64> {
26+
) -> crate::Result<u64, MockConfig> {
2727
Ok(0)
2828
}
2929

3030
async fn save_tapestry_metadata<TID: TapestryId, M: Debug + Clone + Send + Sync>(
3131
&self,
3232
_tapestry_id: TID,
3333
_metadata: M,
34-
) -> crate::Result<()> {
34+
) -> crate::Result<(), MockConfig> {
3535
Ok(())
3636
}
3737

3838
async fn get_instance_index<TID: TapestryId>(
3939
&self,
4040
_tapestry_id: TID,
41-
) -> crate::Result<Option<u16>> {
41+
) -> crate::Result<Option<u16>, MockConfig> {
4242
Ok(Some(0))
4343
}
4444

4545
async fn get_tapestry_fragment<TID: TapestryId>(
4646
&self,
4747
_tapestry_id: TID,
4848
_instance: Option<u64>,
49-
) -> crate::Result<Option<TapestryFragment<MockConfig>>> {
49+
) -> crate::Result<Option<TapestryFragment<MockConfig>>, MockConfig> {
5050
Ok(Some(TapestryFragment { context_tokens: 0, context_messages: vec![] }))
5151
}
5252

5353
async fn get_tapestry_metadata<TID: TapestryId, M: DeserializeOwned>(
5454
&self,
5555
_tapestry_id: TID,
56-
) -> crate::Result<Option<M>> {
56+
) -> crate::Result<Option<M>, MockConfig> {
5757
Ok(Some(serde_json::from_str("{}").unwrap()))
5858
}
5959

60-
async fn delete_tapestry<TID: TapestryId>(&self, _tapestry_id: TID) -> crate::Result<()> {
60+
async fn delete_tapestry<TID: TapestryId>(
61+
&self,
62+
_tapestry_id: TID,
63+
) -> crate::Result<(), MockConfig> {
6164
Ok(())
6265
}
6366

6467
async fn delete_tapestry_fragment<TID: TapestryId>(
6568
&self,
6669
_tapestry_id: TID,
6770
_instance: Option<u64>,
68-
) -> crate::Result<()> {
71+
) -> crate::Result<(), MockConfig> {
6972
Ok(())
7073
}
7174
}
@@ -104,17 +107,14 @@ impl Llm<MockConfig> for MockLlm {
104107
type Parameters = ();
105108
type Request = MockLlmRequest;
106109
type Response = MockLlmResponse;
110+
type PromptError = MockPromptError;
107111

108-
fn count_tokens(content: &str) -> Result<Self::Tokens> {
112+
fn count_tokens(content: &str) -> Result<Self::Tokens, MockConfig> {
109113
let bpe = p50k_base().unwrap();
110114
let tokens = bpe.encode_with_special_tokens(&content.to_string());
111115

112116
tokens.len().try_into().map_err(|_| {
113-
LoomError::from(WeaveError::BadConfig(format!(
114-
"Number of tokens exceeds max tokens for model: {}",
115-
content
116-
)))
117-
.into()
117+
LoomError::Llm(MockPromptError::BadConfig("Token count exceeds u16".to_string()))
118118
})
119119
}
120120

@@ -133,7 +133,7 @@ impl Llm<MockConfig> for MockLlm {
133133
_msgs: Vec<Self::Request>,
134134
_params: &Self::Parameters,
135135
_max_tokens: Self::Tokens,
136-
) -> Result<Self::Response> {
136+
) -> Result<Self::Response, MockConfig> {
137137
Ok(MockLlmResponse {})
138138
}
139139

@@ -203,3 +203,9 @@ impl From<MockLlmResponse> for Option<String> {
203203
Some("TestLlmResponse".to_string())
204204
}
205205
}
206+
207+
#[derive(Debug, Clone, thiserror::Error)]
208+
pub enum MockPromptError {
209+
#[error("Bad configuration: {0}")]
210+
BadConfig(String),
211+
}

src/storage/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,22 @@ pub trait TapestryChestHandler<T: Config> {
4141
tapestry_id: &TID,
4242
tapestry_fragment: TapestryFragment<T>,
4343
increment: bool,
44-
) -> crate::Result<u64>;
44+
) -> crate::Result<u64, T>;
4545
/// Save tapestry metadata.
4646
///
4747
Based on application use cases, you can add additional data for a given [`TapestryId`]
4848
async fn save_tapestry_metadata<TID: TapestryId, M: Serialize + Debug + Clone + Send + Sync>(
4949
&self,
5050
tapestry_id: TID,
5151
metadata: M,
52-
) -> crate::Result<()>;
52+
) -> crate::Result<(), T>;
5353
/// Retrieves the index of a tapestry.
5454
///
5555
/// Returns None if the tapestry does not exist.
5656
async fn get_instance_index<TID: TapestryId>(
5757
&self,
5858
tapestry_id: TID,
59-
) -> crate::Result<Option<u16>>;
59+
) -> crate::Result<Option<u16>, T>;
6060
/// Retrieves the last tapestry fragment, or a fragment at a specified instance.
6161
///
6262
/// # Parameters
@@ -73,18 +73,18 @@ pub trait TapestryChestHandler<T: Config> {
7373
&self,
7474
tapestry_id: TID,
7575
instance: Option<u64>,
76-
) -> crate::Result<Option<TapestryFragment<T>>>;
76+
) -> crate::Result<Option<TapestryFragment<T>>, T>;
7777
/// Retrieves the last tapestry metadata, or a metadata at a specified instance.
7878
async fn get_tapestry_metadata<TID: TapestryId, M: DeserializeOwned + Send + Sync>(
7979
&self,
8080
tapestry_id: TID,
81-
) -> crate::Result<Option<M>>;
81+
) -> crate::Result<Option<M>, T>;
8282
/// Deletes a tapestry and all its instances.
83-
async fn delete_tapestry<TID: TapestryId>(&self, tapestry_id: TID) -> crate::Result<()>;
83+
async fn delete_tapestry<TID: TapestryId>(&self, tapestry_id: TID) -> crate::Result<(), T>;
8484
/// Deletes a tapestry fragment.
8585
async fn delete_tapestry_fragment<TID: TapestryId>(
8686
&self,
8787
tapestry_id: TID,
8888
instance: Option<u64>,
89-
) -> crate::Result<()>;
89+
) -> crate::Result<(), T>;
9090
}

0 commit comments

Comments (0)