@@ -18,9 +18,10 @@ use llama_cpp_2::model::LlamaModel;
18
18
use llama_cpp_2:: token:: data_array:: LlamaTokenDataArray ;
19
19
use llama_cpp_2:: token:: LlamaToken ;
20
20
use llama_cpp_2:: ggml_time_us;
21
+ use rand:: Rng ;
21
22
22
23
use std:: num:: NonZeroU32 ;
23
- use types:: { LlamaResult , ModelConfig , ModelManager , ModelState } ;
24
+ use types:: { LlamaResult , ModelManager , ModelState } ;
24
25
25
26
use std:: collections:: HashMap ;
26
27
use std:: time:: Duration ;
@@ -32,7 +33,6 @@ pub mod types;
32
33
33
34
pub fn load_model (
34
35
path : String ,
35
- model_config : ModelConfig ,
36
36
llama_backend : & LlamaBackend ,
37
37
) -> Result < LlamaModel > {
38
38
let init_params = {
@@ -56,7 +56,7 @@ pub fn load_models(models: Vec<types::ModelDefinition>) -> Result<ModelManager>
56
56
//let arc_llama_backend = Arc::new(llama_backend);
57
57
let mut loaded_models = HashMap :: new ( ) ;
58
58
for model in models {
59
- let llama_model = load_model ( model. path , model . config . clone ( ) , & llama_backend)
59
+ let llama_model = load_model ( model. path , & llama_backend)
60
60
. expect ( "failed to load model" ) ;
61
61
let model_state = types:: ModelState {
62
62
model : llama_model,
@@ -90,7 +90,7 @@ pub fn generate(
90
90
token_callback : Option < TokenCallback > ,
91
91
stops : Option < & Vec < String > >
92
92
) -> Result < LlamaResult > {
93
- let mut batch = LlamaBatch :: new ( 512 , 1 ) ;
93
+ let mut batch = LlamaBatch :: new ( 1024 , 1 ) ;
94
94
let last_index: i32 = ( tokens_list. len ( ) - 1 ) as i32 ;
95
95
for ( i, token) in ( 0_i32 ..) . zip ( tokens_list. into_iter ( ) ) {
96
96
// llama_decode will output logits only for the last token of the prompt
@@ -111,14 +111,16 @@ pub fn generate(
111
111
loop {
112
112
let mut is_last = n_cur == n_len; // Keep track of it here for the callback and use loop to save cycles
113
113
let candidates = ctx. candidates_ith ( batch. n_tokens ( ) - 1 ) ;
114
- let candidates_p = LlamaTokenDataArray :: from_iter ( candidates, false ) ;
114
+ let mut candidates_p = LlamaTokenDataArray :: from_iter ( candidates, false ) ;
115
+ ctx. sample_temp ( & mut candidates_p, 0.1 ) ; //TODO: make this a parameter with the model config object
116
+ //ctx.sample_top_p(&mut candidates_p, 0.1, 128);
117
+ ctx. sample_top_k ( & mut candidates_p, 20 , 128 ) ;
118
+ ctx. sample_typical ( & mut candidates_p, 1.1 , 128 ) ;
115
119
let new_token_id = ctx. sample_token_greedy ( candidates_p) ;
116
120
if new_token_id == model. token_eos ( ) {
117
- println ! ( "EOS token found" ) ;
118
121
is_last = true ;
119
122
}
120
123
let token_str = model. token_to_str ( new_token_id) . expect ( "That UTF8 shit" ) ; // We should make EOS a blank string
121
- println ! ( "{}" , token_str) ;
122
124
if let Some ( stops) = stops {
123
125
if stops. iter ( ) . any ( |stop| token_str. eq ( stop) ) {
124
126
is_last = true ;
@@ -136,9 +138,7 @@ pub fn generate(
136
138
if let Some ( ref token_callback) = token_callback {
137
139
token_callback ( token_str, is_last) ;
138
140
}
139
- if is_last {
140
- break ;
141
- }
141
+
142
142
batch. clear ( ) ;
143
143
batch. add ( new_token_id, n_cur, & [ 0 ] , true ) ?;
144
144
@@ -169,9 +169,12 @@ pub fn pretty_generate(
169
169
stops : & Vec < String > ,
170
170
token_callback : Option < TokenCallback >
171
171
) -> Result < LlamaResult > {
172
+ let mut rng = rand:: thread_rng ( ) ;
173
+ let random_number: u32 = rng. gen ( ) ;
174
+
172
175
let ctx_params = LlamaContextParams :: default ( )
173
- . with_n_ctx ( NonZeroU32 :: new ( 2048 ) )
174
- . with_seed ( 0 ) ;
176
+ . with_n_ctx ( NonZeroU32 :: new ( model . config . num_ctx . unwrap_or ( 2048 ) as u32 ) )
177
+ . with_seed ( random_number ) ;
175
178
176
179
let mut ctx = model. model
177
180
. new_context ( & backend, ctx_params)
0 commit comments