Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions contracts/database/driver/conditions.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ const (
WhereTypeJsonContains
WhereTypeJsonContainsKey
WhereTypeJsonLength
// WhereRelation used for where cause with relation subqueries, like WhereHas, OrWhereHas etc.
WhereRelation
)

type Conditions struct {
Expand Down Expand Up @@ -40,9 +42,10 @@ type Join struct {
}

type Where struct {
Query any
Args []any
Type WhereType
Or bool
IsNot bool
Query any
Args []any
Type WhereType
Relation string
Or bool
IsNot bool
}
2 changes: 2 additions & 0 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ type Query interface {
Where(query any, args ...any) Query
// WhereBetween adds a "where column between x and y" clause to the query.
WhereBetween(column string, x, y any) Query
// WhereHas add a relationship count / exists condition to the query with where clauses.
WhereHas(relation string, callback func(query Query) Query, args ...any) Query
// WhereIn adds a "where column in" clause to the query.
WhereIn(column string, values []any) Query
// WhereJsonContains add a "where JSON contains" clause to the query.
Expand Down
33 changes: 33 additions & 0 deletions database/gorm/operator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package gorm

type Operator = string

const (
gt Operator = ">"
gte Operator = ">="
eq Operator = "="
lte Operator = "<="
lt Operator = "<"
)

func isAnyOperator(s any) (Operator, bool) {
o, ok := s.(string)
if !ok {
return "", false
}

if isOperator(o) {
return o, true
}

return "", false
}

func isOperator(s string) bool {
switch s {
case gt, gte, eq, lte, lt:
return true
default:
return false
}
}
108 changes: 104 additions & 4 deletions database/gorm/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/spf13/cast"
gormio "gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"

"github.com/goravel/framework/contracts/config"
contractsdatabase "github.com/goravel/framework/contracts/database"
Expand Down Expand Up @@ -41,6 +42,8 @@ type Query struct {
mutex sync.Mutex
}

type subqueryCallback = func(contractsorm.Query) contractsorm.Query

func NewQuery(
ctx context.Context,
config config.Config,
Expand Down Expand Up @@ -765,7 +768,7 @@ func (r *Query) Scan(dest any) error {
return query.instance.Scan(dest).Error
}

func (r *Query) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query {
func (r *Query) Scopes(funcs ...subqueryCallback) contractsorm.Query {
conditions := r.conditions
conditions.scopes = deep.Append(r.conditions.scopes, funcs...)

Expand Down Expand Up @@ -1026,6 +1029,21 @@ func (r *Query) WhereNotNull(column string) contractsorm.Query {
return r.Where(fmt.Sprintf("%s IS NOT NULL", column))
}

func (r *Query) WhereHas(relation string, callback subqueryCallback, args ...any) contractsorm.Query {
subquery := NewQuery(r.ctx, r.config, r.dbConfig, r.instance, r.grammar, r.log, r.modelToObserver, nil)

if callback != nil {
subquery = callback(subquery).(*Query)
}

return r.addWhere(contractsdriver.Where{
Query: subquery,
Args: args,
Relation: relation,
Type: contractsdriver.WhereRelation,
})
}

func (r *Query) With(query string, args ...any) contractsorm.Query {
conditions := r.conditions
conditions.with = deep.Append(r.conditions.with, With{
Expand Down Expand Up @@ -1295,7 +1313,7 @@ func (r *Query) buildSharedLock(db *gormio.DB) *gormio.DB {
return db
}

func (r *Query) buildSubquery(sub func(contractsorm.Query) contractsorm.Query) *gormio.DB {
func (r *Query) buildSubquery(sub subqueryCallback) *gormio.DB {
db := r.instance.Session(&gormio.Session{NewDB: true, Initialized: true})
queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, db, r.grammar, r.log, r.modelToObserver, nil)
query := sub(queryImpl)
Expand Down Expand Up @@ -1341,9 +1359,11 @@ func (r *Query) buildWhere(db *gormio.DB) *gormio.DB {
segments := strings.SplitN(item.Query.(string), " ", 2)
segments[0] = r.grammar.CompileJsonLength(segments[0])
item.Query = r.buildWherePlaceholder(strings.Join(segments, " "), item.Args...)
case contractsdriver.WhereRelation:
item.Query, item.Args = r.buildWhereRelation(item)
default:
switch query := item.Query.(type) {
case func(contractsorm.Query) contractsorm.Query:
case subqueryCallback:
item.Query = r.buildSubquery(query)
item.Args = nil
case string:
Expand Down Expand Up @@ -1372,6 +1392,86 @@ func (r *Query) buildWhere(db *gormio.DB) *gormio.DB {
return db
}

func (r *Query) buildWhereRelation(item contractsdriver.Where) (any, []any) {
var (
op Operator
count int64
)

if argsLen := len(item.Args); argsLen > 0 && argsLen != 1 {
o, ok := isAnyOperator(item.Args[0])
if !ok {
return r.instance.AddError(errors.New("the first argument should be string, it uses as operator")), []any{}
}

c, err := cast.ToInt64E(item.Args[1])
if err != nil {
return r.instance.AddError(errors.New("the second argument should be int64, it uses as count")), []any{}
}

op = o
count = c
}

subquery, err := r.relationSubquery(item.Relation, item.Query.(*Query))
if err != nil {
return r.instance.AddError(err), []any{}
}

needCountQuery := !((count == 0 && slices.Contains([]Operator{lt, lte, gt, gte}, op)) || op == "")

if !needCountQuery {
fmt.Println("exists")
modifiedQueryImpl := subquery.(*Query).buildConditions().instance
return "EXISTS (?)", []any{modifiedQueryImpl}
}

modifiedQueryImpl := subquery.Select("count(*)").(*Query)
return "(?) " + op + " ?", []any{modifiedQueryImpl.buildConditions().instance, count}
}

func (r *Query) relationSubquery(relation string, subquery contractsorm.Query) (contractsorm.Query, error) {
mSchema, err := getModelSchema(r.conditions.model, r.instance)
if err != nil {
return nil, fmt.Errorf("faild to get model schema, the model should be set before using this method. %w", err)
}

fmt.Println(relation)
rel, ok := mSchema.Relationships.Relations[relation]
if !ok {
return nil, fmt.Errorf("relation not found. %s", relation)
}
relModel := getZeroValueFromReflectType(rel.Field.FieldType)

subquery = subquery.Model(relModel)

fmt.Printf("%+v\n", subquery)

fk := rel.References[0].ForeignKey.DBName
ft := rel.FieldSchema.Table
table := mSchema.Table

switch rel.Type {
case schema.BelongsTo:
pk := rel.FieldSchema.PrioritizedPrimaryField.DBName
subquery = subquery.Where(database.QuoteConcat(ft, pk) + " = " + database.QuoteConcat(table, fk))
case schema.HasOne, schema.HasMany:
pk := mSchema.PrioritizedPrimaryField.DBName
subquery = subquery.Where(database.QuoteConcat(ft, fk) + " = " + database.QuoteConcat(table, pk))
case schema.Many2Many:
joinTable := rel.JoinTable.Table
pk := mSchema.PrioritizedPrimaryField.DBName
subquery = subquery.
Join("inner join " +
database.Quote(joinTable) +
" on " +
database.QuoteConcat(mSchema.Table, pk) +
" = " + database.QuoteConcat(joinTable, fk))
}

return subquery, nil
}

func (r *Query) buildWherePlaceholder(query string, args ...any) string {
// if query does not contain a placeholder,it might be incorrectly quoted or treated as an expression
// to avoid errors, append a manual placeholder
Expand All @@ -1393,7 +1493,7 @@ func (r *Query) buildWith(db *gormio.DB) *gormio.DB {
for _, item := range r.conditions.with {
isSet := false
if len(item.args) == 1 {
if arg, ok := item.args[0].(func(contractsorm.Query) contractsorm.Query); ok {
if arg, ok := item.args[0].(subqueryCallback); ok {
newArgs := []any{
func(tx *gormio.DB) *gormio.DB {
queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, tx, r.grammar, r.log, r.modelToObserver, nil)
Expand Down
63 changes: 63 additions & 0 deletions database/gorm/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,69 @@ func TestAddWhere(t *testing.T) {
}, query1.conditions.where)
}

func TestAddWhereHas(t *testing.T) {
type Organization struct {
Model
Name string
}
type Post struct {
Model
Title string
}
type Role struct {
Model
Title string
}
type User struct {
Model
Name string
Posts []*Post
Organization *Organization
Roles []*Role `gorm:"many2many:role_user"`
}

query := (&Query{}).
Model(&User{}).
WhereHas("Posts", nil).
WhereHas("Organization", func(q contractsorm.Query) contractsorm.Query { return q.Where("name", "John") }).
WhereHas("Roles", nil, ">=", 10)

assert.Equal(t, &Query{
conditions: Conditions{
where: []contractsdriver.Where{
contractsdriver.Where{
Query: &Query{queries: make(map[string]*Query)},
Args: nil,
Relation: "Posts",
Type: contractsdriver.WhereRelation,
},
contractsdriver.Where{
Query: &Query{
queries: make(map[string]*Query),
conditions: Conditions{
where: []contractsdriver.Where{contractsdriver.Where{
Query: "name",
Args: []any{"John"},
}},
},
},
Args: nil,
Relation: "Organization",
Type: contractsdriver.WhereRelation,
},
contractsdriver.Where{
Query: &Query{queries: make(map[string]*Query)},
Args: []any{">=", 10},
Relation: "Roles",
Type: contractsdriver.WhereRelation,
},
},
model: &User{},
},
queries: make(map[string]*Query),
}, query)
}

func TestGetObserver(t *testing.T) {
query := &Query{
modelToObserver: []contractsorm.ModelToObserver{
Expand Down
23 changes: 22 additions & 1 deletion database/gorm/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package gorm

import "reflect"
import (
"reflect"

gormio "gorm.io/gorm"
"gorm.io/gorm/schema"
)

func copyStruct(dest any) reflect.Value {
t := reflect.TypeOf(dest)
Expand All @@ -18,3 +23,19 @@ func copyStruct(dest any) reflect.Value {

return v.Convert(copyDestStruct)
}

func getZeroValueFromReflectType(t reflect.Type) any {
if t.Kind() == reflect.Pointer {
return reflect.New(t.Elem()).Interface()
}
return reflect.New(t.Elem()).Interface()
}

func getModelSchema(model any, db *gormio.DB) (*schema.Schema, error) {
stmt := gormio.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
return nil, err
}
return stmt.Schema, nil
}
8 changes: 8 additions & 0 deletions support/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ func GetIDByReflect(t reflect.Type, v reflect.Value) any {

return nil
}

func Quote(name string) string {
return "`" + name + "`"
}

func QuoteConcat(table string, col string) string {
return Quote(table) + "." + Quote(col)
}
Loading