
Commit e6a6813

Fix the random bullshit issue
1 parent 31b7356 commit e6a6813

6 files changed (+67 -15)

Cargo.lock

+48
Some generated files are not rendered by default.

Cargo.toml

+1 -1

@@ -4,4 +4,4 @@ members = [
     "api"
 ]
 
-resolver = "2"
+resolver = "2"

api/Cargo.toml

+1 -1

@@ -14,4 +14,4 @@ serde_json = "1.0"
 shurbai = { path = "../shurbai" }
 futures = "0.3.30"
 tokio-stream = "0.1.14"
-axum-streams = { version = "0.12", features=["json", "text"] }
+axum-streams = { version = "0.12", features=["json", "text"] }

api/src/routes.rs

+1 -1

@@ -1,5 +1,5 @@
 
-use std::{fmt::format, io::Write, sync::Arc};
+use std::sync::Arc;
 
 use axum_streams::StreamBodyAs;
 
shurbai/Cargo.toml

+1

@@ -8,4 +8,5 @@ edition = "2021"
 [dependencies]
 anyhow = "1.0.80"
 llama-cpp-2 = "0.1.27"
+rand = "0.8.5"
 serde = { version = "1.0", features = ["derive"] }

shurbai/src/lib.rs

+15 -12

@@ -18,9 +18,10 @@ use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use llama_cpp_2::token::LlamaToken;
 use llama_cpp_2::ggml_time_us;
+use rand::Rng;
 
 use std::num::NonZeroU32;
-use types::{LlamaResult, ModelConfig, ModelManager, ModelState};
+use types::{LlamaResult, ModelManager, ModelState};
 
 use std::collections::HashMap;
 use std::time::Duration;

@@ -32,7 +33,6 @@ pub mod types;
 
 pub fn load_model(
     path: String,
-    model_config: ModelConfig,
     llama_backend: &LlamaBackend,
 ) -> Result<LlamaModel> {
     let init_params = {

@@ -56,7 +56,7 @@ pub fn load_models(models: Vec<types::ModelDefinition>) -> Result<ModelManager>
     //let arc_llama_backend = Arc::new(llama_backend);
     let mut loaded_models = HashMap::new();
     for model in models {
-        let llama_model = load_model(model.path, model.config.clone(), &llama_backend)
+        let llama_model = load_model(model.path, &llama_backend)
             .expect("failed to load model");
         let model_state = types::ModelState {
             model: llama_model,
@@ -90,7 +90,7 @@ pub fn generate(
     token_callback: Option<TokenCallback>,
     stops: Option<&Vec<String>>
 ) -> Result<LlamaResult> {
-    let mut batch = LlamaBatch::new(512, 1);
+    let mut batch = LlamaBatch::new(1024, 1);
     let last_index: i32 = (tokens_list.len() - 1) as i32;
     for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
         // llama_decode will output logits only for the last token of the prompt
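For orientation, the enlarged batch is what the surrounding prompt loop fills before decoding. A minimal sketch of that pattern, using the llama-cpp-2 calls that appear elsewhere in this diff (LlamaBatch::new, batch.add) plus a ctx.decode step that is not shown in the hunk; decode_prompt is a hypothetical helper for illustration, not code from the repository:

use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::LlamaToken;

// Sketch only: feeds a prompt into a batch the way generate() appears to.
// Logits are requested only for the last prompt token, since that is the
// position the sampler reads from.
fn decode_prompt(ctx: &mut LlamaContext, tokens_list: Vec<LlamaToken>) -> anyhow::Result<()> {
    let mut batch = LlamaBatch::new(1024, 1); // capacity raised from 512 in this commit
    let last_index: i32 = (tokens_list.len() - 1) as i32;
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        batch.add(token, i, &[0], i == last_index)?;
    }
    ctx.decode(&mut batch)?;
    Ok(())
}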
@@ -111,14 +111,16 @@ pub fn generate(
     loop {
         let mut is_last = n_cur == n_len; // Keep track of it here for the callback and use loop to save cycles
         let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
-        let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+        let mut candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+        ctx.sample_temp(&mut candidates_p, 0.1); //TODO: make this a parameter with the model config object
+        //ctx.sample_top_p(&mut candidates_p, 0.1, 128);
+        ctx.sample_top_k(&mut candidates_p, 20, 128);
+        ctx.sample_typical(&mut candidates_p, 1.1, 128);
         let new_token_id = ctx.sample_token_greedy(candidates_p);
         if new_token_id == model.token_eos() {
-            println!("EOS token found");
             is_last = true;
         }
         let token_str = model.token_to_str(new_token_id).expect("That UTF8 shit"); // We should make EOS a blank string
-        println!("{}", token_str);
         if let Some(stops) = stops {
             if stops.iter().any(|stop| token_str.eq(stop)) {
                 is_last = true;
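Instead of greedy-picking from the raw logits, the loop now reshapes the candidate distribution first. A minimal sketch of just that chain, using the llama-cpp-2 sampling calls exactly as they appear in the hunk; sample_next is a hypothetical wrapper, and the constants mirror the hard-coded values that the TODO says should eventually come from the model config:

use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

// Sketch of the sampling chain introduced in this hunk.
fn sample_next(ctx: &mut LlamaContext, mut candidates_p: LlamaTokenDataArray) -> LlamaToken {
    ctx.sample_temp(&mut candidates_p, 0.1); // temperature scaling (hard-coded for now)
    ctx.sample_top_k(&mut candidates_p, 20, 128); // top-k truncation, min_keep = 128
    ctx.sample_typical(&mut candidates_p, 1.1, 128); // locally typical filtering
    ctx.sample_token_greedy(candidates_p) // pick the most likely remaining token
}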
@@ -136,9 +138,7 @@ pub fn generate(
         if let Some(ref token_callback) = token_callback {
             token_callback(token_str, is_last);
         }
-        if is_last {
-            break;
-        }
+
         batch.clear();
         batch.add(new_token_id, n_cur, &[0], true)?;
 
@@ -169,9 +169,12 @@ pub fn pretty_generate(
     stops: &Vec<String>,
     token_callback: Option<TokenCallback>
 ) -> Result<LlamaResult> {
+    let mut rng = rand::thread_rng();
+    let random_number: u32 = rng.gen();
+
     let ctx_params = LlamaContextParams::default()
-        .with_n_ctx(NonZeroU32::new(2048))
-        .with_seed(0);
+        .with_n_ctx(NonZeroU32::new(model.config.num_ctx.unwrap_or(2048) as u32))
+        .with_seed(random_number);
 
     let mut ctx = model.model
         .new_context(&backend, ctx_params)
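The context was previously created with a fixed seed of 0 and a fixed 2048-token window; it now takes the window from the model config and seeds from rand. A minimal standalone sketch of that setup, where num_ctx is an assumed stand-in for the model.config.num_ctx field and the builder methods are the ones shown in the hunk:

use std::num::NonZeroU32;

use llama_cpp_2::context::params::LlamaContextParams;
use rand::Rng;

// Sketch: build context params the way pretty_generate does after this commit.
fn context_params(num_ctx: Option<u32>) -> LlamaContextParams {
    let mut rng = rand::thread_rng();
    let random_number: u32 = rng.gen(); // fresh seed per call instead of the old hard-coded 0

    LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(num_ctx.unwrap_or(2048)))
        .with_seed(random_number)
}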
