
Commit 006df84

feature: add support for using yzma package to call llama.cpp
This adds a new llm subpackage that uses the yzma package to call the llama.cpp libraries directly through the FFI interface.

Signed-off-by: deadprogram <[email protected]>
1 parent 7ca1d89 commit 006df84

File tree

8 files changed: +439 -0 lines changed
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@

# yzma LLM Example 🚀

## What Does This Example Do? 🤔

This example shows you how to:

1. Set up a local LLM client
2. Generate text using a simple prompt
3. Customize the LLM configuration (with some cool commented-out options)

## The Magic Explained ✨

Here's what's happening in our main function:

1. We create a new yzma LLM client using `yzma.New()`. This uses default settings from your environment.
2. We set up a context for our LLM operations.
3. We generate text by asking the LLM a simple question: "How many sides does a square have?"
4. Finally, we print the LLM's response!

## Cool Features to Explore 🕵️‍♀️

While the example uses default settings, it also shows you how to customize your LLM:

- There are options to set top-k, top-p, and temperature values for text generation.

## Running the Example 🏃‍♂️

Just compile and run the Go file, and you'll see the LLM's response to the square question. It's that simple!

## Have Fun! 🎉

This example is a great starting point for experimenting with local LLMs. Feel free to uncomment the additional options and play around with different configurations. Happy coding!
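
For readers who want to try the customization the README mentions, here is a minimal sketch (not part of the committed files) with the example's commented-out call options enabled. It assumes the YZMA_LIB environment variable points at your llama.cpp shared libraries, as required by the constructor added in llms/yzma/yzma.go, and the model path is a placeholder.

```go
// Minimal sketch, not part of the commit: the example's commented-out
// call options enabled. YZMA_LIB must point at the llama.cpp shared
// libraries; the model path below is a placeholder.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/yzma"
)

func main() {
	llm, err := yzma.New(yzma.WithModel("/path/to/model.gguf"))
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()

	completion, err := llms.GenerateFromSinglePrompt(
		context.Background(), llm,
		"How many sides does a square have?",
		llms.WithTopK(10),
		llms.WithTopP(0.95),
		llms.WithTemperature(0.25),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
```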

examples/yzma-llm-example/go.mod

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
module github.com/tmc/langchaingo/examples/yzma-llm-example

go 1.24.4

replace github.com/tmc/langchaingo => ../..

require github.com/tmc/langchaingo v0.0.0-00010101000000-000000000000

require (
	github.com/dlclark/regexp2 v1.10.0 // indirect
	github.com/ebitengine/purego v0.8.4 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/hybridgroup/yzma v0.7.0 // indirect
	github.com/jupiterrider/ffi v0.5.1 // indirect
	github.com/pkoukk/tiktoken-go v0.1.6 // indirect
	golang.org/x/sys v0.36.0 // indirect
)

examples/yzma-llm-example/go.sum

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hybridgroup/yzma v0.7.0 h1:VKuIzQSeqZgK4162cCTP2HaJvYlpRJJROWoZgzW4uAU=
github.com/hybridgroup/yzma v0.7.0/go.mod h1:hqcOnvdEmI0ci1UHo9AStKmTgqWIXTyEiU7ZnQz0HCU=
github.com/jupiterrider/ffi v0.5.1 h1:l7ANXU+Ex33LilVa283HNaf/sTzCrrht7D05k6T6nlc=
github.com/jupiterrider/ffi v0.5.1/go.mod h1:x7xdNKo8h0AmLuXfswDUBxUsd2OqUP4ekC8sCnsmbvo=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/yzma"
)

const modelPath = "/home/ron/models/SmolLM2-135M-Instruct.Q2_K.gguf"

func main() {
	llm, err := yzma.New(yzma.WithModel(modelPath))
	if err != nil {
		log.Fatal(err)
	}

	// Init context
	ctx := context.Background()

	completion, err := llms.GenerateFromSinglePrompt(ctx, llm, "How many sides does a square have?")
	// Optionally, pass call options to override the default generation settings, e.g.:
	//generateOptions := []llms.CallOption{
	//	llms.WithTopK(10),
	//	llms.WithTopP(0.95),
	//	llms.WithTemperature(0.25),
	//}
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
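
Not part of the commit, but for context: the wrapper added in llms/yzma/yzma.go below also implements a Call convenience method, so a one-off completion could look like this sketch, reusing the llm and ctx names from the example above.

```go
// Sketch only, reusing llm and ctx from the example above.
answer, err := llm.Call(ctx, "How many sides does a square have?", llms.WithMaxTokens(32))
if err != nil {
	log.Fatal(err)
}
fmt.Println(answer)
```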

llms/yzma/llm_test.go

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
package yzma

import (
	"os"
	"testing"

	"github.com/tmc/langchaingo/testing/llmtest"
)

func TestLLM(t *testing.T) {
	testModel := os.Getenv("YZMA_TEST_MODEL")
	if testModel == "" {
		t.Skip("YZMA_TEST_MODEL not set to point to test model")
	}

	llm, err := New(WithModel(testModel))
	if err != nil {
		t.Fatalf("Failed to create yzma LLM: %v", err)
	}
	defer llm.Close()

	llmtest.TestLLM(t, llm)
}

llms/yzma/options.go

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
package yzma

type options struct {
	model  string
	system string
}

type Option func(*options)

// WithModel sets the model to use.
func WithModel(model string) Option {
	return func(opts *options) {
		opts.model = model
	}
}

// WithSystemPrompt sets the system prompt.
func WithSystemPrompt(p string) Option {
	return func(opts *options) {
		opts.system = p
	}
}
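
A short sketch (not part of the commit) of how these functional options compose. Note that, as far as this diff shows, the value stored by WithSystemPrompt is kept on the options struct but not yet read during generation; the model path is a placeholder.

```go
// Sketch only: both options shown. The system prompt is stored on the
// options struct but, in this diff, not yet consumed by GenerateContent.
llm, err := yzma.New(
	yzma.WithModel("/path/to/model.gguf"), // placeholder path
	yzma.WithSystemPrompt("You answer in one short sentence."),
)
if err != nil {
	log.Fatal(err)
}
defer llm.Close()
```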

llms/yzma/yzma.go

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
package yzma

import (
	"context"
	"errors"
	"os"

	"github.com/hybridgroup/yzma/pkg/llama"
	"github.com/tmc/langchaingo/llms"
)

const (
	defaultTemperature = 0.8
	defaultTopK        = 40
	defaultTopP        = 0.9
)

// LLM is a yzma local implementation wrapper that calls directly into the llama.cpp libs using the FFI interface.
type LLM struct {
	model   string
	options options
}

// New creates a new yzma LLM implementation.
func New(opts ...Option) (*LLM, error) {
	o := options{}
	for _, opt := range opts {
		opt(&o)
	}

	libPath := os.Getenv("YZMA_LIB")
	if libPath == "" {
		return nil, errors.New("no path to yzma libs")
	}

	if err := llama.Load(""); err != nil {
		return nil, err
	}

	llama.LogSet(llama.LogSilent())
	llama.Init()

	llm := LLM{
		model:   o.model,
		options: o,
	}

	return &llm, nil
}

// Close frees all resources.
func (o *LLM) Close() {
	llama.BackendFree()
}

// Call calls yzma with the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...)
}

// GenerateContent implements the Model interface.
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	modelName := o.model
	if opts.Model != "" {
		modelName = opts.Model
	}

	// TODO: allow for setting any passed model params
	model := llama.ModelLoadFromFile(modelName, llama.ModelDefaultParams())
	if model == llama.Model(0) {
		return nil, errors.New("unable to load model")
	}
	defer llama.ModelFree(model)

	// TODO: allow for setting any passed context options
	ctxParams := llama.ContextDefaultParams()
	ctxParams.NCtx = uint32(4096)
	ctxParams.NBatch = uint32(2048)

	lctx := llama.InitFromModel(model, ctxParams)
	if lctx == llama.Context(0) {
		return nil, errors.New("unable to init model")
	}

	defer llama.Free(lctx)

	vocab := llama.ModelGetVocab(model)

	temperature := defaultTemperature
	if opts.Temperature > 0 {
		temperature = opts.Temperature
	}
	topK := defaultTopK
	if opts.TopK > 0 {
		topK = opts.TopK
	}

	minP := 0.1

	topP := defaultTopP
	if opts.TopP > 0 {
		topP = opts.TopP
	}

	sampler := llama.SamplerChainInit(llama.SamplerChainDefaultParams())
	llama.SamplerChainAdd(sampler, llama.SamplerInitTopK(int32(topK)))
	llama.SamplerChainAdd(sampler, llama.SamplerInitTopP(float32(topP), 1))
	llama.SamplerChainAdd(sampler, llama.SamplerInitMinP(float32(minP), 1))
	llama.SamplerChainAdd(sampler, llama.SamplerInitTempExt(float32(temperature), 0, 1.0))
	llama.SamplerChainAdd(sampler, llama.SamplerInitDist(llama.DefaultSeed))

	// gets the default template
	template := llama.ModelChatTemplate(model, "")
	if template == "" {
		template = "chatml"
	}

	msgs := []llama.ChatMessage{}
	for _, m := range messages {
		p := m.Parts[0]
		switch pt := p.(type) {
		case llms.TextContent:
			msgs = append(msgs, llama.NewChatMessage(string(m.Role), pt.Text))
		default:
			return nil, errors.New("only support Text parts right now")
		}
	}

	msg := chatTemplate(template, msgs, true)

	// call once to get the size of the tokens from the prompt
	count := llama.Tokenize(vocab, msg, nil, true, true)

	// now get the actual tokens
	tokens := make([]llama.Token, count)
	llama.Tokenize(vocab, msg, tokens, true, true)

	batch := llama.BatchGetOne(tokens)

	if llama.ModelHasEncoder(model) {
		llama.Encode(lctx, batch)

		start := llama.ModelDecoderStartToken(model)
		if start == llama.TokenNull {
			start = llama.VocabBOS(vocab)
		}

		batch = llama.BatchGetOne([]llama.Token{start})
	}

	result := ""

	maxTokens := int32(1024)
	if opts.MaxTokens > 0 {
		maxTokens = int32(opts.MaxTokens)
	}

	for pos := int32(0); pos < maxTokens; pos += batch.NTokens {
		llama.Decode(lctx, batch)
		token := llama.SamplerSample(sampler, lctx, -1)

		if llama.VocabIsEOG(vocab, token) {
			break
		}

		buf := make([]byte, 64)
		n := llama.TokenToPiece(vocab, token, buf, 0, true)

		result = result + string(buf[:n])
		batch = llama.BatchGetOne([]llama.Token{token})
	}

	choices := []*llms.ContentChoice{
		{
			Content: result,
		},
	}

	response := &llms.ContentResponse{Choices: choices}
	return response, nil
}

func chatTemplate(template string, msgs []llama.ChatMessage, add bool) string {
	buf := make([]byte, 2048)
	n := llama.ChatApplyTemplate(template, msgs, add, buf)
	result := string(buf[:n])
	return result
}
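
For orientation (not part of the commit), here is a sketch of calling GenerateContent directly with role-tagged messages, which the implementation above folds into the model's chat template; the llm and ctx names are assumed from the earlier example.

```go
// Sketch only: role-tagged messages are rendered through the model's chat
// template by GenerateContent above. llm and ctx come from the earlier example.
content := []llms.MessageContent{
	llms.TextParts(llms.ChatMessageTypeSystem, "You answer with a single number."),
	llms.TextParts(llms.ChatMessageTypeHuman, "How many sides does a square have?"),
}

resp, err := llm.GenerateContent(ctx, content, llms.WithMaxTokens(64))
if err != nil {
	log.Fatal(err)
}
fmt.Println(resp.Choices[0].Content)
```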
