Skip to content

Commit

Permalink
test: add basic tests (#59)
Browse files Browse the repository at this point in the history
Signed-off-by: Sertac Ozercan <[email protected]>
  • Loading branch information
sozercan authored Dec 27, 2023
1 parent e47b7b1 commit eadf6d1
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 0 deletions.
69 changes: 69 additions & 0 deletions pkg/aikit/config/specs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package config

import (
"reflect"
"testing"

"github.com/sozercan/aikit/pkg/utils"
)

func TestNewFromBytes(t *testing.T) {
type args struct {
b []byte
}
tests := []struct {
name string
args args
want *Config
wantErr bool
}{
{
name: "valid yaml",
args: args{b: []byte(`
apiVersion: v1alpha1
runtime: avx512
backends:
- exllama
- stablediffusion
models:
- name: test
source: foo
`)},
want: &Config{
APIVersion: utils.APIv1alpha1,
Runtime: utils.RuntimeCPUAVX512,
Backends: []string{
utils.BackendExllama,
utils.BackendStableDiffusion,
},
Models: []Model{
{
Name: "test",
Source: "foo",
},
},
},
wantErr: false,
},
{
name: "invalid yaml",
args: args{b: []byte(`
foo
`)},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewFromBytes(tt.args.b)
if (err != nil) != tt.wantErr {
t.Errorf("NewFromBytes() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewFromBytes() = %v, want %v", got, tt.want)
}
})
}
}
40 changes: 40 additions & 0 deletions pkg/aikit2llb/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package aikit2llb

import (
"testing"
)

func Test_fileNameFromURL(t *testing.T) {
type args struct {
urlString string
}
tests := []struct {
name string
args args
want string
}{
{
name: "simple",
args: args{urlString: "http://foo.bar/baz"},
want: "baz",
},
{
name: "complex",
args: args{urlString: "http://foo.bar/baz.tar.gz"},
want: "baz.tar.gz",
},
{
name: "complex with path",
args: args{urlString: "http://foo.bar/baz.tar.gz?foo=bar"},
want: "baz.tar.gz",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := fileNameFromURL(tt.args.urlString); got != tt.want {
t.Errorf("fileNameFromURL() = %v, want %v", got, tt.want)
}
})
}
}

7 changes: 7 additions & 0 deletions pkg/build/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ func validateConfig(c *config.Config) error {
return errors.New("exllama only supports nvidia cuda runtime. please add 'runtime: cuda' to your aikitfile.yaml")
}

backends := []string{utils.BackendExllama, utils.BackendExllamaV2, utils.BackendStableDiffusion}
for _, b := range c.Backends {
if !slices.Contains(backends, b) {
return errors.Errorf("backend %s is not supported", b)
}
}

runtimes := []string{"", utils.RuntimeNVIDIA, utils.RuntimeCPUAVX, utils.RuntimeCPUAVX2, utils.RuntimeCPUAVX512}
if !slices.Contains(runtimes, c.Runtime) {
return errors.Errorf("runtime %s is not supported", c.Runtime)
Expand Down
126 changes: 126 additions & 0 deletions pkg/build/build_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package build

import (
"testing"

"github.com/sozercan/aikit/pkg/aikit/config"
)

func Test_validateConfig(t *testing.T) {
type args struct {
c *config.Config
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "no config",
args: args{c: &config.Config{}},
wantErr: true,
},
{
name: "unsupported api version",
args: args{c: &config.Config{
APIVersion: "v10",
}},
wantErr: true,
},
{
name: "invalid runtime",
args: args{c: &config.Config{
APIVersion: "v1",
Runtime: "foo",
}},
wantErr: true,
},
{
name: "no models",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
}},
wantErr: true,
},
{
name: "valid backend",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
Runtime: "cuda",
Backends: []string{"exllama"},
Models: []config.Model{
{
Name: "test",
Source: "foo",
},
},
}},
wantErr: false,
},
{
name: "invalid backend",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
Backends: []string{"foo"},
Models: []config.Model{
{
Name: "test",
Source: "foo",
},
},
}},
wantErr: true,
},
{
name: "valid backend but no cuda runtime",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
Backends: []string{"exllama"},
Models: []config.Model{
{
Name: "test",
Source: "foo",
},
},
}},
wantErr: true,
},
{
name: "invalid backend combination 1",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
Runtime: "cuda",
Backends: []string{"exllama", "exllama2"},
Models: []config.Model{
{
Name: "test",
Source: "foo",
},
},
}},
wantErr: true,
},
{
name: "invalid backend combination 2",
args: args{c: &config.Config{
APIVersion: "v1alpha1",
Runtime: "cuda",
Backends: []string{"exllama", "stablediffusion"},
Models: []config.Model{
{
Name: "test",
Source: "foo",
},
},
}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := validateConfig(tt.args.c); (err != nil) != tt.wantErr {
t.Errorf("validateConfig() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

0 comments on commit eadf6d1

Please sign in to comment.