Commit f85b5c3

feat: add mamba backend (#109)
Signed-off-by: Sertac Ozercan <[email protected]>

1 parent 62bbc88 commit f85b5c3

File tree: 8 files changed (+117 −35 lines)

.github/workflows/test-docker-gpu.yaml

Lines changed: 25 additions & 3 deletions

```diff
@@ -15,11 +15,15 @@ jobs:
       matrix:
         backend:
           - llama-cuda
-          # - exllama
+          # - exllama # https://github.com/sozercan/aikit/issues/94
           - exllama2-gptq
           - exllama2-exl2
+          - mamba
     steps:
-      - uses: AutoModality/action-clean@11d611e7824ef8f2fe7f05a117d1ffe4c1a090f0 # v1.1.1
+      - name: cleanup workspace
+        run: |
+          rm -rf ./* || true
+          rm -rf ./.??* || true
       - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
 
       - name: build aikit
@@ -40,8 +44,9 @@ jobs:
         run: docker run --name testmodel -d --rm -p 8080:8080 --gpus all testmodel:test
 
       - name: run test
+        if: matrix.backend != 'mamba'
         run: |
-          result=$(curl --fail --retry 5 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
+          result=$(curl --fail --retry 10 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
             "model": "llama-2-7b-chat",
             "messages": [{"role": "user", "content": "explain kubernetes in a sentence"}]
           }')
@@ -52,13 +57,30 @@ jobs:
             exit 1
           fi
 
+      - name: run test
+        if: matrix.backend == 'mamba'
+        run: |
+          result=$(curl --fail --retry 10 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
+            "model": "mamba-chat",
+            "messages": [{"role": "user", "content": "explain kubernetes in a sentence"}]
+          }')
+          echo $result
+
+          choices=$(echo "$result" | jq '.choices')
+          if [ -z "$choices" ]; then
+            exit 1
+          fi
+
       - name: save logs
         if: always()
         run: docker logs testmodel > /tmp/docker-${{ matrix.backend }}.log
 
       - run: docker stop testmodel
         if: always()
 
+      - run: docker system prune -a -f --volumes
+        if: always()
+
       - name: publish test artifacts
         if: always()
         uses: actions/upload-artifact@694cdabd8bdb0f10b2cea11669e1bf5453eed0a6 # v4.2.0
```
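
The mamba step above is the same chat-completions smoke test as the llama one: POST a prompt, then use jq to assert that a non-empty `choices` array came back. Below is a minimal Go sketch of that check, assuming a model image built from this commit is already serving on 127.0.0.1:8080; the file name and `main` scaffolding are illustrative, and curl's `--retry 10` behavior is not replicated.

```go
// smoketest.go: a sketch of the workflow's curl/jq check, assuming an
// OpenAI-compatible server is already listening on 127.0.0.1:8080.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

func main() {
	body := []byte(`{
		"model": "mamba-chat",
		"messages": [{"role": "user", "content": "explain kubernetes in a sentence"}]
	}`)

	resp, err := http.Post("http://127.0.0.1:8080/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	// Mirror `jq '.choices'`: fail if the response carries no choices.
	var out struct {
		Choices []json.RawMessage `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil || len(out.Choices) == 0 {
		os.Exit(1)
	}
	fmt.Printf("got %d choice(s)\n", len(out.Choices))
}
```

Once the container from the `docker run` step is up, `go run smoketest.go` performs the same pass/fail check as the workflow step.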

.github/workflows/test-docker.yaml

Lines changed: 2 additions & 2 deletions

```diff
@@ -78,7 +78,7 @@ jobs:
       - name: run llama test
         if: matrix.backend == 'llama'
         run: |
-          result=$(curl --fail --retry 5 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
+          result=$(curl --fail --retry 10 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
             "model": "llama-2-7b-chat",
             "messages": [{"role": "user", "content": "explain kubernetes in a sentence"}]
           }')
@@ -92,7 +92,7 @@ jobs:
       - name: run stablediffusion test
         if: matrix.backend == 'stablediffusion'
         run: |
-          result=$(curl --fail --retry 5 --retry-all-errors http://127.0.0.1:8080/v1/images/generations -H "Content-Type: application/json" -d '{
+          result=$(curl --fail --retry 10 --retry-all-errors http://127.0.0.1:8080/v1/images/generations -H "Content-Type: application/json" -d '{
             "prompt": "A cute baby llama",
             "size": "256x256"
           }')
```

.github/workflows/test-kubernetes.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -84,7 +84,7 @@ jobs:
 
       - name: run test
         run: |
-          result=$(curl --fail --retry 5 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
+          result=$(curl --fail --retry 10 --retry-all-errors http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
             "model": "llama-2-7b-chat",
             "messages": [{"role": "user", "content": "explain kubernetes in a sentence"}]
           }')
```

pkg/aikit2llb/convert.go

Lines changed: 40 additions & 21 deletions

```diff
@@ -32,7 +32,7 @@ func Aikit2LLB(c *config.Config) (llb.State, *specs.Image) {
 
     // install cuda if runtime is nvidia
     if c.Runtime == utils.RuntimeNVIDIA {
-        merge = installCuda(c, state, merge)
+        state, merge = installCuda(c, state, merge)
     }
 
     // install opencv and friends if stable diffusion backend is being used
@@ -43,6 +43,8 @@ func Aikit2LLB(c *config.Config) (llb.State, *specs.Image) {
             merge = installExllama(c, state, merge)
         case utils.BackendStableDiffusion:
             merge = installOpenCV(state, merge)
+        case utils.BackendMamba:
+            merge = installMamba(state, merge)
         }
     }
 
@@ -51,26 +53,14 @@ func Aikit2LLB(c *config.Config) (llb.State, *specs.Image) {
 }
 
 func getBaseImage(c *config.Config) llb.State {
-    for b := range c.Backends {
-        switch c.Backends[b] {
-        case utils.BackendExllama:
-        case utils.BackendExllamaV2:
-            return llb.Image(debianSlim)
-        case utils.BackendStableDiffusion:
-            return llb.Image(debianSlim)
-        }
+    if len(c.Backends) > 0 {
+        return llb.Image(debianSlim)
     }
     return llb.Image(distrolessBase)
 }
 
 func copyModels(c *config.Config, base llb.State, s llb.State) (llb.State, llb.State) {
     savedState := s
-
-    // create config file if defined
-    if c.Config != "" {
-        s = s.Run(shf("echo -n \"%s\" > /config.yaml", c.Config)).Root()
-    }
-
     for _, model := range c.Models {
         var opts []llb.HTTPOption
         opts = append(opts, llb.Filename(fileNameFromURL(model.Source)))
@@ -104,6 +94,12 @@ func copyModels(c *config.Config, base llb.State, s llb.State) (llb.State, llb.S
             }
         }
     }
+
+    // create config file if defined
+    if c.Config != "" {
+        s = s.Run(shf("echo -n \"%s\" > /config.yaml", c.Config)).Root()
+    }
+
     diff := llb.Diff(savedState, s)
     merge := llb.Merge([]llb.State{base, diff})
     return s, merge
@@ -117,18 +113,19 @@ func fileNameFromURL(urlString string) string {
     return path.Base(parsedURL.Path)
 }
 
-func installCuda(c *config.Config, s llb.State, merge llb.State) llb.State {
+func installCuda(c *config.Config, s llb.State, merge llb.State) (llb.State, llb.State) {
     cudaKeyringURL := "https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb"
     cudaKeyring := llb.HTTP(cudaKeyringURL)
     s = s.File(
         llb.Copy(cudaKeyring, fileNameFromURL(cudaKeyringURL), "/"),
         llb.WithCustomName("Copying "+fileNameFromURL(cudaKeyringURL)), //nolint: goconst
     )
     s = s.Run(sh("dpkg -i cuda-keyring_1.1-1_all.deb && rm cuda-keyring_1.1-1_all.deb")).Root()
+
+    savedState := s
     // running apt-get update twice due to nvidia repo
     s = s.Run(sh("apt-get update && apt-get install -y ca-certificates && apt-get update"), llb.IgnoreCache).Root()
 
-    savedState := s
     // install cuda libraries
     if len(c.Backends) == 0 {
         s = s.Run(shf("apt-get install -y --no-install-recommends libcublas-%[1]s cuda-cudart-%[1]s && apt-get clean", cudaVersion)).Root()
@@ -149,20 +146,25 @@ func installCuda(c *config.Config, s llb.State, merge llb.State) llb.State {
 
             s = s.Run(sh(exllamaDeps)).Root()
         }
+
+        if c.Backends[b] == utils.BackendMamba {
+            mambaDeps := fmt.Sprintf("apt-get install -y --no-install-recommends cuda-crt-%[1]s cuda-cudart-dev-%[1]s cuda-nvcc-%[1]s && apt-get clean", cudaVersion)
+            s = s.Run(sh(mambaDeps)).Root()
+        }
     }
 
     diff := llb.Diff(savedState, s)
-    return llb.Merge([]llb.State{merge, diff})
+    return s, llb.Merge([]llb.State{merge, diff})
 }
 
 func installExllama(c *config.Config, s llb.State, merge llb.State) llb.State {
-    backend := "exllama"
+    backend := utils.BackendExllama
     exllamaRepo := "https://github.com/turboderp/exllama"
     exllamaTag := "master"
     for b := range c.Backends {
         if c.Backends[b] == utils.BackendExllamaV2 {
             exllamaRepo = "https://github.com/turboderp/exllamav2"
-            backend = "exllama2"
+            backend = utils.BackendExllamaV2
             exllamaTag = "v0.0.11"
         }
     }
@@ -171,7 +173,7 @@ func installExllama(c *config.Config, s llb.State, merge llb.State) llb.State {
     s = s.Run(sh("apt-get update && apt-get install --no-install-recommends -y git ca-certificates python3-pip python3-dev g++ && apt-get clean"), llb.IgnoreCache).Root()
 
     // clone localai exllama backend only
-    s = s.Run(shf("git clone --filter=blob:none --no-checkout %[1]s /tmp/localai/ && cd /tmp/localai && git sparse-checkout init --cone && git sparse-checkout set backend/python/%[2]s && git checkout %[3]s && rm -rf .git", localAIRepo, backend, localAIVersion)).Root()
+    s = cloneLocalAI(s, backend)
 
     // clone exllama to localai exllama backend path and install python dependencies
     s = s.Run(shf("git clone --depth 1 %[1]s --branch %[2]s /tmp/%[3]s && mv /tmp/%[3]s/* /tmp/localai/backend/python/%[3]s && rm -rf /tmp/%[3]s && cd /tmp/localai/backend/python/%[3]s && rm -rf .git && pip3 install grpcio protobuf typing-extensions sympy mpmath setuptools numpy --break-system-packages && pip3 install -r /tmp/localai/backend/python/%[3]s/requirements.txt --break-system-packages", exllamaRepo, exllamaTag, backend)).Root()
@@ -180,6 +182,19 @@ func installExllama(c *config.Config, s llb.State, merge llb.State) llb.State {
     return llb.Merge([]llb.State{merge, diff})
 }
 
+func installMamba(s llb.State, merge llb.State) llb.State {
+    savedState := s
+    // libexpat1 is requirement but git is not. however libexpat1 is a dependency of git
+    s = s.Run(sh("apt-get install --no-install-recommends -y git python3 python3-dev python3-pip libssl3 openssl && apt-get clean"), llb.IgnoreCache).Root()
+
+    s = cloneLocalAI(s, utils.BackendMamba)
+
+    s = s.Run(shf("pip3 install packaging numpy torch==2.1.0 grpcio protobuf --break-system-packages && pip3 install causal-conv1d==1.0.0 mamba-ssm==1.0.1 --break-system-packages")).Root()
+
+    diff := llb.Diff(savedState, s)
+    return llb.Merge([]llb.State{merge, diff})
+}
+
 func installOpenCV(s llb.State, merge llb.State) llb.State {
     savedState := s
     // adding debian 11 (bullseye) repo due to opencv 4.5 requirement
@@ -233,6 +248,10 @@ func addLocalAI(c *config.Config, s llb.State, merge llb.State) (llb.State, llb.
     return s, llb.Merge([]llb.State{merge, diff})
 }
 
+func cloneLocalAI(s llb.State, backend string) llb.State {
+    return s.Run(shf("git clone --filter=blob:none --no-checkout %[1]s /tmp/localai/ && cd /tmp/localai && git sparse-checkout init --cone && git sparse-checkout set backend/python/%[2]s && git checkout %[3]s && rm -rf .git", localAIRepo, backend, localAIVersion)).Root()
+}
+
 func shf(cmd string, v ...interface{}) llb.RunOption {
     return llb.Args([]string{"/bin/sh", "-c", fmt.Sprintf(cmd, v...)})
 }
```
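
The recurring shape in this file is: snapshot the state (`savedState := s`), run the install commands, then `llb.Diff` the snapshot against the result and `llb.Merge` only that delta onto the output image. The signature change to installCuda exists so the CUDA-enabled state can be threaded forward into later installers (installMamba's nvcc toolchain is layered on top of it in Aikit2LLB) instead of each installer diffing against the bare base. Below is a minimal self-contained sketch of that pattern, assuming BuildKit's `client/llb` package; the `install` helper and its commands are illustrative, not aikit's actual API.

```go
// diffmerge.go: a sketch of the savedState/Diff/Merge pattern used in convert.go.
package main

import (
	"github.com/moby/buildkit/client/llb"
)

// install runs cmd on s and merges only the resulting filesystem delta into
// merge, mirroring the helpers in pkg/aikit2llb/convert.go.
func install(s llb.State, merge llb.State, cmd string) (llb.State, llb.State) {
	savedState := s
	s = s.Run(llb.Args([]string{"/bin/sh", "-c", cmd})).Root()

	// Diff captures just what cmd changed; Merge stacks that layer onto the output.
	diff := llb.Diff(savedState, s)
	return s, llb.Merge([]llb.State{merge, diff})
}

func main() {
	state := llb.Image("docker.io/library/debian:12-slim")
	merge := state

	// Each step threads the updated state forward, as installCuda does after
	// this commit, so the next step's Diff is computed against it.
	state, merge = install(state, merge, "apt-get update && apt-get install -y ca-certificates")
	_, merge = install(state, merge, "pip3 install mamba-ssm") // hypothetical follow-on step
	_ = merge
}
```

Moving `savedState := s` before the `apt-get update` run (first hunk of installCuda) also means the ca-certificates layer now lands in the diff that gets merged, rather than being lost to the snapshot boundary.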

pkg/aikit2llb/image.go

Lines changed: 8 additions & 1 deletion

```diff
@@ -46,12 +46,19 @@ func emptyImage(c *config.Config) *specs.Image {
     }
 
     for b := range c.Backends {
-        if c.Backends[b] == utils.BackendExllama || c.Backends[b] == utils.BackendExllamaV2 {
+        switch c.Backends[b] {
+        case utils.BackendExllama, utils.BackendExllamaV2:
             exllamaEnv := []string{
                 "EXTERNAL_GRPC_BACKENDS=exllama:/tmp/localai/backend/python/exllama/exllama.py,exllama2:/tmp/localai/backend/python/exllama2/exllama2_backend.py",
                 "CUDA_HOME=/usr/local/cuda",
             }
             img.Config.Env = append(img.Config.Env, exllamaEnv...)
+        case utils.BackendMamba:
+            mambaEnv := []string{
+                "EXTERNAL_GRPC_BACKENDS=mamba:/tmp/localai/backend/python/mamba/backend_mamba.py",
+                "CUDA_HOME=/usr/local/cuda",
+            }
+            img.Config.Env = append(img.Config.Env, mambaEnv...)
         }
     }
 
```
pkg/build/build.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,19 @@ func validateConfig(c *config.Config) error {
128128
return errors.New("no models defined")
129129
}
130130

131-
if slices.Contains(c.Backends, utils.BackendStableDiffusion) && (slices.Contains(c.Backends, utils.BackendExllama) || slices.Contains(c.Backends, utils.BackendExllamaV2)) {
132-
return errors.New("cannot specify both stablediffusion with exllama or exllama2 at this time")
131+
if len(c.Backends) > 1 {
132+
return errors.New("only one backend is supported at this time")
133133
}
134134

135-
if slices.Contains(c.Backends, utils.BackendExllama) && slices.Contains(c.Backends, utils.BackendExllamaV2) {
136-
return errors.New("cannot specify both exllama and exllamav2 at this time")
135+
if slices.Contains(c.Backends, utils.BackendStableDiffusion) && (slices.Contains(c.Backends, utils.BackendExllama) || slices.Contains(c.Backends, utils.BackendExllamaV2)) {
136+
return errors.New("cannot specify both stablediffusion with exllama or exllama2 at this time")
137137
}
138138

139-
if (slices.Contains(c.Backends, utils.BackendExllama) || slices.Contains(c.Backends, utils.BackendExllamaV2)) && c.Runtime != utils.RuntimeNVIDIA {
140-
return errors.New("exllama only supports nvidia cuda runtime. please add 'runtime: cuda' to your aikitfile.yaml")
139+
if (slices.Contains(c.Backends, utils.BackendExllama) || slices.Contains(c.Backends, utils.BackendExllamaV2) || slices.Contains(c.Backends, utils.BackendMamba)) && c.Runtime != utils.RuntimeNVIDIA {
140+
return errors.New("exllama and mamba only supports nvidia cuda runtime. please add 'runtime: cuda' to your aikitfile.yaml")
141141
}
142142

143-
backends := []string{utils.BackendExllama, utils.BackendExllamaV2, utils.BackendStableDiffusion}
143+
backends := []string{utils.BackendExllama, utils.BackendExllamaV2, utils.BackendStableDiffusion, utils.BackendMamba}
144144
for _, b := range c.Backends {
145145
if !slices.Contains(backends, b) {
146146
return errors.Errorf("backend %s is not supported", b)
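
Validation now short-circuits on multiple backends before any pairwise compatibility checks run, and the NVIDIA-runtime requirement covers mamba alongside exllama. A hypothetical table-driven test sketch of that behavior follows; the `config.Config` and `config.Model` field names (`Models`, `Backends`, `Runtime`, `Name`, `Source`), the literal `"cuda"` runtime value, and the import path are all inferred from the usage above, not confirmed against the repo.

```go
// build_test.go: a hypothetical sketch, same package as validateConfig.
package build

import (
	"testing"

	"github.com/sozercan/aikit/pkg/aikit/config" // assumed import path
)

func TestValidateConfigBackends(t *testing.T) {
	// minimal model list so the "no models defined" check passes (fields assumed)
	models := []config.Model{{Name: "m", Source: "https://example.com/m.bin"}}

	cases := []struct {
		name    string
		cfg     *config.Config
		wantErr bool
	}{
		// mamba without the cuda runtime must be rejected
		{"mamba requires cuda", &config.Config{Models: models,
			Backends: []string{"mamba"}}, true},
		// multiple backends are rejected before any pairwise checks run
		{"single backend only", &config.Config{Models: models,
			Backends: []string{"exllama", "mamba"}, Runtime: "cuda"}, true},
		{"mamba with cuda is valid", &config.Config{Models: models,
			Backends: []string{"mamba"}, Runtime: "cuda"}, false},
	}
	for _, tc := range cases {
		if err := validateConfig(tc.cfg); (err != nil) != tc.wantErr {
			t.Errorf("%s: got err=%v, wantErr=%v", tc.name, err, tc.wantErr)
		}
	}
}
```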

pkg/utils/const.go

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@ const (
     BackendStableDiffusion = "stablediffusion"
     BackendExllama         = "exllama"
     BackendExllamaV2       = "exllama2"
+    BackendMamba           = "mamba"
 
     APIv1alpha1 = "v1alpha1"
 )
```

test/aikitfile-mamba.yaml

Lines changed: 33 additions & 0 deletions

```diff
@@ -0,0 +1,33 @@
+#syntax=aikit:test
+apiVersion: v1alpha1
+debug: true
+runtime: cuda
+backends:
+  - mamba
+models:
+  - name: mamba-chat/config.json
+    source: https://huggingface.co/havenhq/mamba-chat/raw/d343f8ade4c870d916b362746dd23821aae132dd/config.json
+  - name: mamba-chat/pytorch_model.bin
+    source: https://huggingface.co/havenhq/mamba-chat/resolve/d343f8ade4c870d916b362746dd23821aae132dd/pytorch_model.bin
+    sha256: 6751a8c3888564a90a7f759a620e2ddfc1ab2cc3e919f2cbaf7bfc41cc5f85e7
+  - name: mamba-chat/tokenizer.json
+    source: https://huggingface.co/havenhq/mamba-chat/raw/d343f8ade4c870d916b362746dd23821aae132dd/tokenizer.json
+  - name: mamba-chat/tokenizer_config.json
+    source: https://huggingface.co/havenhq/mamba-chat/raw/d343f8ade4c870d916b362746dd23821aae132dd/tokenizer_config.json
+config: |
+  - name: mamba-chat
+    backend: mamba
+    parameters:
+      model: /models/mamba-chat
+    trimsuffix:
+    - <|endoftext|>
+    template:
+      chat_message: |
+        {{if eq .RoleName \"assistant\"}}<|assistant|>{{else if eq .RoleName \"system\"}}<|system|>{{else if eq .RoleName \"user\"}}<|user|>{{end}}
+        {{if .Content}}{{.Content}}{{end}}
+        </s>
+      chat: |
+        {{.Input}}
+        <|assistant|>
+      completion: |
+        {{.Input}}
```
