Commit 28ef5f4

feature: add support for using yzma package to call llama.cpp
This adds a new llm subpackage that uses the yzma package to call llama.cpp libraries directly via the FFI interface.

Signed-off-by: deadprogram <[email protected]>
1 parent 7ca1d89 commit 28ef5f4

8 files changed: +447 −0 lines changed
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# yzma LLM Example 🚀

## What Does This Example Do? 🤔

This example shows you how to:

1. Set up a local LLM client
2. Generate text using a simple prompt
3. Customize the LLM configuration (with some cool commented-out options)

## The Magic Explained ✨

Here's what's happening in our main function:

1. We create a new yzma LLM client using `yzma.New()`. This uses default settings from your environment.
2. We set up a context for our LLM operations.
3. We generate text by asking the LLM a simple question: "How many sides does a square have?"
4. Finally, we print the LLM's response!

## Cool Features to Explore 🕵️‍♀️

While the example uses default settings, it also shows you how to customize your LLM:

- There are options to set top-k, top-p, and temperature values for text generation.

## Running the Example 🏃‍♂️

Just compile and run the Go file, and you'll see the LLM's response to the square question. It's that simple!

## Have Fun! 🎉

This example is a great starting point for experimenting with local LLMs. Feel free to uncomment the additional options and play around with different configurations. Happy coding!
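As a concrete illustration of the top-k, top-p, and temperature options mentioned above, here is a minimal sketch of passing them as call options when generating text. It is not part of the committed example; the model path is a placeholder, and it assumes only the langchaingo `llms` call-option helpers plus the `yzma` constructor added in this commit:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/yzma"
)

func main() {
	// Placeholder path; point this at a local GGUF model.
	// New also expects the YZMA_LIB environment variable to locate the llama.cpp libraries.
	llm, err := yzma.New(yzma.WithModel("/path/to/model.gguf"))
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()

	// Pass sampling options on the call instead of relying on the defaults.
	completion, err := llms.GenerateFromSinglePrompt(
		context.Background(), llm,
		"How many sides does a square have?",
		llms.WithTopK(10),
		llms.WithTopP(0.95),
		llms.WithTemperature(0.25),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
```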

examples/yzma-llm-example/go.mod

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
module github.com/tmc/langchaingo/examples/yzma-llm-example

go 1.24.4

replace github.com/tmc/langchaingo => ../..

require github.com/tmc/langchaingo v0.0.0-00010101000000-000000000000

require (
	github.com/dlclark/regexp2 v1.10.0 // indirect
	github.com/ebitengine/purego v0.8.4 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/hybridgroup/yzma v0.7.0 // indirect
	github.com/jupiterrider/ffi v0.5.1 // indirect
	github.com/pkoukk/tiktoken-go v0.1.6 // indirect
	golang.org/x/sys v0.36.0 // indirect
)

examples/yzma-llm-example/go.sum

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hybridgroup/yzma v0.7.0 h1:VKuIzQSeqZgK4162cCTP2HaJvYlpRJJROWoZgzW4uAU=
github.com/hybridgroup/yzma v0.7.0/go.mod h1:hqcOnvdEmI0ci1UHo9AStKmTgqWIXTyEiU7ZnQz0HCU=
github.com/jupiterrider/ffi v0.5.1 h1:l7ANXU+Ex33LilVa283HNaf/sTzCrrht7D05k6T6nlc=
github.com/jupiterrider/ffi v0.5.1/go.mod h1:x7xdNKo8h0AmLuXfswDUBxUsd2OqUP4ekC8sCnsmbvo=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/yzma"
)

const modelPath = "/home/ron/models/SmolLM2-135M-Instruct.Q2_K.gguf"

func main() {
	llm, err := yzma.New(yzma.WithModel(modelPath))
	if err != nil {
		log.Fatal(err)
	}

	// Init context
	ctx := context.Background()

	completion, err := llms.GenerateFromSinglePrompt(ctx, llm, "How many sides does a square have?")
	// Or append to default args options from global llms.Options
	//generateOptions := []llms.CallOption{
	//	llms.WithTopK(10),
	//	llms.WithTopP(0.95),
	//	llms.WithTemperature(0.25),
	//}
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}

llms/yzma/llm_test.go

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
package yzma

import (
	"os"
	"testing"

	"github.com/tmc/langchaingo/testing/llmtest"
)

func TestLLM(t *testing.T) {
	testModel := os.Getenv("YZMA_TEST_MODEL")
	if testModel == "" {
		t.Skip("YZMA_TEST_MODEL not set to point to test model")
	}

	llm, err := New(WithModel(testModel))
	if err != nil {
		t.Fatalf("Failed to create yzma LLM: %v", err)
	}
	defer llm.Close()

	llmtest.TestLLM(t, llm)
}

llms/yzma/options.go

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
package yzma

type options struct {
	model  string
	system string
}

type Option func(*options)

// WithModel sets the model to use.
func WithModel(model string) Option {
	return func(opts *options) {
		opts.model = model
	}
}

// WithSystemPrompt sets the system prompt.
func WithSystemPrompt(p string) Option {
	return func(opts *options) {
		opts.system = p
	}
}
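A short sketch of how these functional options are combined at construction time. Both values below are illustrative placeholders, and as of this commit the system prompt is simply stored on the client's options struct:

```go
package main

import (
	"log"

	"github.com/tmc/langchaingo/llms/yzma"
)

func main() {
	// Each Option mutates the private options struct before the client is built.
	llm, err := yzma.New(
		yzma.WithModel("/path/to/model.gguf"),                 // placeholder model path
		yzma.WithSystemPrompt("You are a concise assistant."), // stored on the options struct
	)
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()
}
```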

llms/yzma/yzma.go

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
package yzma

import (
	"context"
	"errors"
	"os"

	"github.com/hybridgroup/yzma/pkg/llama"
	"github.com/tmc/langchaingo/llms"
)

const (
	defaultTemperature = 0.8
	defaultTopK        = 40
	defaultTopP        = 0.9
)

// LLM is a yzma local implementation wrapper to call directly to llama.cpp libs using the FFI interface.
type LLM struct {
	model   string
	options options
}

// New creates a new yzma LLM implementation.
func New(opts ...Option) (*LLM, error) {
	o := options{}
	for _, opt := range opts {
		opt(&o)
	}

	libPath := os.Getenv("YZMA_LIB")
	if libPath == "" {
		return nil, errors.New("no path to yzma libs")
	}

	if err := llama.Load(""); err != nil {
		return nil, err
	}

	llama.LogSet(llama.LogSilent())
	llama.Init()

	llm := LLM{
		model:   o.model,
		options: o,
	}

	return &llm, nil
}

// Close frees all resources.
func (o *LLM) Close() {
	llama.BackendFree()
}

// Call calls yzma with the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...)
}

// GenerateContent implements the Model interface.
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	modelName := o.model
	if opts.Model != "" {
		modelName = opts.Model
	}

	maxTokens := int32(1024)
	if opts.MaxTokens > 0 {
		maxTokens = int32(opts.MaxTokens)
	}

	// TODO: allow for setting any passed model params
	model := llama.ModelLoadFromFile(modelName, llama.ModelDefaultParams())
	if model == llama.Model(0) {
		return nil, errors.New("unable to load model")
	}
	defer llama.ModelFree(model)

	// TODO: allow for setting any passed context options
	ctxParams := llama.ContextDefaultParams()
	ctxParams.NCtx = uint32(4096)
	ctxParams.NBatch = uint32(2048)

	lctx := llama.InitFromModel(model, ctxParams)
	if lctx == llama.Context(0) {
		return nil, errors.New("unable to init model")
	}

	defer llama.Free(lctx)

	vocab := llama.ModelGetVocab(model)
	sampler := initSampler(opts)

	msg := chatTemplate(templateForModel(model), convertMessageContent(messages), true)

	// call once to get the size of the tokens from the prompt
	count := llama.Tokenize(vocab, msg, nil, true, true)

	// now get the actual tokens
	tokens := make([]llama.Token, count)
	llama.Tokenize(vocab, msg, tokens, true, true)

	batch := llama.BatchGetOne(tokens)

	if llama.ModelHasEncoder(model) {
		llama.Encode(lctx, batch)

		start := llama.ModelDecoderStartToken(model)
		if start == llama.TokenNull {
			start = llama.VocabBOS(vocab)
		}

		batch = llama.BatchGetOne([]llama.Token{start})
	}

	result := ""

	for pos := int32(0); pos < maxTokens; pos += batch.NTokens {
		llama.Decode(lctx, batch)
		token := llama.SamplerSample(sampler, lctx, -1)

		if llama.VocabIsEOG(vocab, token) {
			break
		}

		buf := make([]byte, 64)
		len := llama.TokenToPiece(vocab, token, buf, 0, true)

		result = result + string(buf[:len])
		batch = llama.BatchGetOne([]llama.Token{token})
	}

	choices := []*llms.ContentChoice{
		{
			Content: result,
		},
	}

	response := &llms.ContentResponse{Choices: choices}
	return response, nil
}

func initSampler(opts llms.CallOptions) llama.Sampler {
	temperature := defaultTemperature
	if opts.Temperature > 0 {
		temperature = opts.Temperature
	}
	topK := defaultTopK
	if opts.TopK > 0 {
		topK = opts.TopK
	}

	minP := 0.1

	topP := defaultTopP
	if opts.TopP > 0 {
		topP = opts.TopP
	}

	sampler := llama.SamplerChainInit(llama.SamplerChainDefaultParams())
	llama.SamplerChainAdd(sampler, llama.SamplerInitTopK(int32(topK)))
	llama.SamplerChainAdd(sampler, llama.SamplerInitTopP(float32(topP), 1))
	llama.SamplerChainAdd(sampler, llama.SamplerInitMinP(float32(minP), 1))
	llama.SamplerChainAdd(sampler, llama.SamplerInitTempExt(float32(temperature), 0, 1.0))
	llama.SamplerChainAdd(sampler, llama.SamplerInitDist(llama.DefaultSeed))

	return sampler
}

func templateForModel(model llama.Model) string {
	template := llama.ModelChatTemplate(model, "")
	if template == "" {
		template = "chatml"
	}
	return template
}

func convertMessageContent(msgs []llms.MessageContent) []llama.ChatMessage {
	chatMsgs := []llama.ChatMessage{}
	for _, m := range msgs {
		p := m.Parts[0]
		switch pt := p.(type) {
		case llms.TextContent:
			chatMsgs = append(chatMsgs, llama.NewChatMessage(string(m.Role), pt.Text))
		}
	}
	return chatMsgs
}

func chatTemplate(template string, msgs []llama.ChatMessage, add bool) string {
	buf := make([]byte, 2048)
	len := llama.ChatApplyTemplate(template, msgs, add, buf)
	result := string(buf[:len])
	return result
}
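Beyond the single-prompt helper used by the example, GenerateContent can be driven directly with a slice of chat messages. The following is a minimal sketch, assuming the existing langchaingo llms helpers (llms.TextParts, the ChatMessageType constants, llms.WithMaxTokens) and a placeholder model path:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/yzma"
)

func main() {
	// Placeholder model path; New also expects YZMA_LIB to point at the llama.cpp libraries.
	llm, err := yzma.New(yzma.WithModel("/path/to/model.gguf"))
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()

	// Build a system and a human message; each becomes one llama.ChatMessage internally.
	msgs := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeSystem, "You are a concise assistant."),
		llms.TextParts(llms.ChatMessageTypeHuman, "How many sides does a square have?"),
	}

	resp, err := llm.GenerateContent(context.Background(), msgs, llms.WithMaxTokens(64))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Content)
}
```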
