Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,69 @@ func (c *queryConverter) BuildNotExpr(expr elastic.Query) (elastic.Query, error)
if expr == nil {
return nil, nil
}
return elastic.NewBoolQuery().MustNot(expr), nil
if bq, ok := expr.(*boolQuery); ok && len(bq.shouldClauses) > 0 {
// !(a || b) == !a && !b
ret := newBoolQuery()
ret.mustNotClauses = bq.shouldClauses
return ret, nil
}
return newBoolQuery().MustNot(expr), nil
}

func (c *queryConverter) BuildAndExpr(exprs ...elastic.Query) (elastic.Query, error) {
var reusableBoolQuery *boolQuery
validExprs := make([]elastic.Query, 0, len(exprs))
for _, e := range exprs {
if e != nil {
if e == nil {
continue
}
if bq, ok := e.(*boolQuery); !ok || len(bq.filterClauses)+len(bq.mustNotClauses) == 0 {
validExprs = append(validExprs, e)
} else if reusableBoolQuery == nil {
reusableBoolQuery = bq
} else {
reusableBoolQuery.Filter(bq.filterClauses...).MustNot(bq.mustNotClauses...)
}
}
if reusableBoolQuery != nil {
reusableBoolQuery.Filter(validExprs...)
return reusableBoolQuery, nil
}
if len(validExprs) == 0 {
return nil, nil
}
if len(validExprs) == 1 {
return validExprs[0], nil
}
return elastic.NewBoolQuery().Filter(validExprs...), nil
return newBoolQuery().Filter(validExprs...), nil
}

func (c *queryConverter) BuildOrExpr(exprs ...elastic.Query) (elastic.Query, error) {
var reusableBoolQuery *boolQuery
validExprs := make([]elastic.Query, 0, len(exprs))
for _, e := range exprs {
if e != nil {
if e == nil {
continue
}
if bq, ok := e.(*boolQuery); !ok || len(bq.shouldClauses) == 0 {
validExprs = append(validExprs, e)
} else if reusableBoolQuery == nil {
reusableBoolQuery = bq
} else {
reusableBoolQuery.Should(bq.shouldClauses...)
}
}
if reusableBoolQuery != nil {
reusableBoolQuery.Should(validExprs...)
return reusableBoolQuery, nil
}
if len(validExprs) == 0 {
return nil, nil
}
if len(validExprs) == 1 {
return validExprs[0], nil
}
return elastic.NewBoolQuery().Should(validExprs...).MinimumNumberShouldMatch(1), nil
return newBoolQuery().Should(validExprs...).MinimumNumberShouldMatch(1), nil
}

func (c *queryConverter) ConvertComparisonExpr(
Expand Down Expand Up @@ -140,7 +170,7 @@ func (c *queryConverter) ConvertTextComparisonExpr(
case sqlparser.EqualStr:
return elastic.NewMatchQuery(colName, value), nil
case sqlparser.NotEqualStr:
return elastic.NewBoolQuery().MustNot(elastic.NewMatchQuery(colName, value)), nil
return newBoolQuery().MustNot(elastic.NewMatchQuery(colName, value)), nil
default:
return nil, query.NewOperatorNotSupportedError(col.Alias, col.ValueType, operator)
}
Expand All @@ -156,7 +186,7 @@ func (c *queryConverter) ConvertRangeExpr(
case sqlparser.BetweenStr:
return elastic.NewRangeQuery(colName).Gte(from).Lte(to), nil
case sqlparser.NotBetweenStr:
return elastic.NewBoolQuery().MustNot(elastic.NewRangeQuery(colName).Gte(from).Lte(to)), nil
return newBoolQuery().MustNot(elastic.NewRangeQuery(colName).Gte(from).Lte(to)), nil
default:
// This should be impossible since the query parser only calls this function with one of those
// operators strings.
Expand All @@ -175,7 +205,7 @@ func (c *queryConverter) ConvertIsExpr(
colName := col.FieldName
switch operator {
case sqlparser.IsNullStr:
return elastic.NewBoolQuery().MustNot(elastic.NewExistsQuery(colName)), nil
return newBoolQuery().MustNot(elastic.NewExistsQuery(colName)), nil
case sqlparser.IsNotNullStr:
return elastic.NewExistsQuery(colName), nil
default:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package elasticsearch

import (
"fmt"

"github.com/olivere/elastic/v7"
)

// This is a wrapper for elastic.BoolQuery so we can access the clauses and be able to combine
// queries and avoid nesting queries when possible.
type boolQuery struct {
mustNotClauses []elastic.Query
filterClauses []elastic.Query
shouldClauses []elastic.Query
minimumShouldMatch string
}

var _ elastic.Query = (*boolQuery)(nil)

func newBoolQuery() *boolQuery {
return &boolQuery{}
}

func (q *boolQuery) MustNot(queries ...elastic.Query) *boolQuery {
q.mustNotClauses = append(q.mustNotClauses, queries...)
return q
}

func (q *boolQuery) Filter(filters ...elastic.Query) *boolQuery {
q.filterClauses = append(q.filterClauses, filters...)
return q
}

func (q *boolQuery) Should(queries ...elastic.Query) *boolQuery {
q.shouldClauses = append(q.shouldClauses, queries...)
return q
}

func (q *boolQuery) MinimumNumberShouldMatch(minimumNumberShouldMatch int) *boolQuery {
q.minimumShouldMatch = fmt.Sprintf("%d", minimumNumberShouldMatch)
return q
}

func (q *boolQuery) Source() (any, error) {
return elastic.NewBoolQuery().
MustNot(q.mustNotClauses...).
Filter(q.filterClauses...).
Should(q.shouldClauses...).
MinimumShouldMatch(q.minimumShouldMatch).
Source()
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func TestQueryConverter_BuildParenExpr(t *testing.T) {
},
{
name: "bool query",
in: elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
out: elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
in: newBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
out: newBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
},
}

Expand Down Expand Up @@ -65,12 +65,23 @@ func TestQueryConverter_BuildNotExpr(t *testing.T) {
{
name: "term query",
in: elastic.NewTermQuery("field", "foo"),
out: elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("field", "foo")),
out: newBoolQuery().MustNot(elastic.NewTermQuery("field", "foo")),
},
{
name: "bool query",
in: elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
out: elastic.NewBoolQuery().MustNot(elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field", "foo"))),
in: newBoolQuery().Filter(elastic.NewTermQuery("field", "foo")),
out: newBoolQuery().MustNot(newBoolQuery().Filter(elastic.NewTermQuery("field", "foo"))),
},
{
name: "not or query",
in: newBoolQuery().Should(
elastic.NewTermQuery("field1", "foo"),
elastic.NewTermQuery("field2", "bar"),
),
out: newBoolQuery().MustNot(
elastic.NewTermQuery("field1", "foo"),
elastic.NewTermQuery("field2", "bar"),
),
},
}

Expand Down Expand Up @@ -110,25 +121,47 @@ func TestQueryConverter_BuildAndExpr(t *testing.T) {
name: "two queries",
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
},
out: elastic.NewBoolQuery().Filter(
out: newBoolQuery().Filter(
elastic.NewTermQuery("field2", "bar"),
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
),
},
{
name: "multiple queries",
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
nil,
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
newBoolQuery().Should(elastic.NewTermQuery("field2", "bar")).MinimumNumberShouldMatch(1),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
},
out: elastic.NewBoolQuery().Filter(
out: newBoolQuery().
Filter(
elastic.NewTermQuery("field1", "foo"),
newBoolQuery().Should(elastic.NewTermQuery("field2", "bar")).MinimumNumberShouldMatch(1),
).
MustNot(elastic.NewTermQuery("field3", "zzz")),
},
{
name: "multiple queries reuse",
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
nil,
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
newBoolQuery().
Filter(elastic.NewTermQuery("field4", "aaa")).
MustNot(elastic.NewTermQuery("field5", "bbb")),
},
out: newBoolQuery().
Filter(
elastic.NewTermQuery("field2", "bar"),
elastic.NewTermQuery("field4", "aaa"),
elastic.NewTermQuery("field1", "foo"),
).MustNot(
elastic.NewTermQuery("field3", "zzz"),
elastic.NewTermQuery("field5", "bbb"),
),
},
}
Expand Down Expand Up @@ -169,12 +202,12 @@ func TestQueryConverter_BuildOrExpr(t *testing.T) {
name: "two queries",
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
},
out: elastic.NewBoolQuery().
out: newBoolQuery().
Should(
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
).
MinimumNumberShouldMatch(1),
},
Expand All @@ -183,14 +216,32 @@ func TestQueryConverter_BuildOrExpr(t *testing.T) {
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
nil,
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
},
out: newBoolQuery().
Should(
elastic.NewTermQuery("field1", "foo"),
newBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
).
MinimumNumberShouldMatch(1),
},
{
name: "multiple queries reuse",
in: []elastic.Query{
elastic.NewTermQuery("field1", "foo"),
nil,
newBoolQuery().Should(elastic.NewTermQuery("field2", "bar")).MinimumNumberShouldMatch(1),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
newBoolQuery().Should(elastic.NewTermQuery("field4", "aaa")).MinimumNumberShouldMatch(1),
},
out: elastic.NewBoolQuery().
out: newBoolQuery().
Should(
elastic.NewTermQuery("field2", "bar"),
elastic.NewTermQuery("field4", "aaa"),
elastic.NewTermQuery("field1", "foo"),
elastic.NewBoolQuery().Filter(elastic.NewTermQuery("field2", "bar")),
elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
newBoolQuery().MustNot(elastic.NewTermQuery("field3", "zzz")),
).
MinimumNumberShouldMatch(1),
},
Expand Down Expand Up @@ -262,7 +313,7 @@ func TestQueryConverter_ConvertComparisonExpr(t *testing.T) {
operator: sqlparser.NotEqualStr,
col: intCol,
value: 123,
out: elastic.NewBoolQuery().MustNot(elastic.NewTermQuery(intCol.FieldName, 123)),
out: newBoolQuery().MustNot(elastic.NewTermQuery(intCol.FieldName, 123)),
},
{
name: "operator in",
Expand All @@ -276,7 +327,7 @@ func TestQueryConverter_ConvertComparisonExpr(t *testing.T) {
operator: sqlparser.NotInStr,
col: intCol,
value: []any{123, 456},
out: elastic.NewBoolQuery().MustNot(
out: newBoolQuery().MustNot(
elastic.NewTermsQuery(intCol.FieldName, 123, 456),
),
},
Expand Down Expand Up @@ -365,7 +416,7 @@ func TestQueryConverter_ConvertKeywordComparisonExpr(t *testing.T) {
operator: sqlparser.NotEqualStr,
col: keywordCol,
value: "foo",
out: elastic.NewBoolQuery().MustNot(elastic.NewTermQuery(keywordCol.FieldName, "foo")),
out: newBoolQuery().MustNot(elastic.NewTermQuery(keywordCol.FieldName, "foo")),
},
{
name: "operator in",
Expand All @@ -379,7 +430,7 @@ func TestQueryConverter_ConvertKeywordComparisonExpr(t *testing.T) {
operator: sqlparser.NotInStr,
col: keywordCol,
value: []any{"foo", "bar"},
out: elastic.NewBoolQuery().MustNot(
out: newBoolQuery().MustNot(
elastic.NewTermsQuery(keywordCol.FieldName, "foo", "bar"),
),
},
Expand All @@ -395,7 +446,7 @@ func TestQueryConverter_ConvertKeywordComparisonExpr(t *testing.T) {
operator: sqlparser.NotStartsWithStr,
col: keywordCol,
value: "foo",
out: elastic.NewBoolQuery().MustNot(elastic.NewPrefixQuery(keywordCol.FieldName, "foo")),
out: newBoolQuery().MustNot(elastic.NewPrefixQuery(keywordCol.FieldName, "foo")),
},
{
name: "operator starts with invalid value",
Expand Down Expand Up @@ -474,7 +525,7 @@ func TestQueryConverter_ConvertKeywordListComparisonExpr(t *testing.T) {
operator: sqlparser.NotEqualStr,
col: keywordListCol,
value: "foo",
out: elastic.NewBoolQuery().MustNot(elastic.NewTermQuery(keywordListCol.FieldName, "foo")),
out: newBoolQuery().MustNot(elastic.NewTermQuery(keywordListCol.FieldName, "foo")),
},
{
name: "operator in",
Expand All @@ -488,7 +539,7 @@ func TestQueryConverter_ConvertKeywordListComparisonExpr(t *testing.T) {
operator: sqlparser.NotInStr,
col: keywordListCol,
value: []any{"foo", "bar"},
out: elastic.NewBoolQuery().MustNot(
out: newBoolQuery().MustNot(
elastic.NewTermsQuery(keywordListCol.FieldName, "foo", "bar"),
),
},
Expand Down Expand Up @@ -549,7 +600,7 @@ func TestQueryConverter_ConvertTextComparisonExpr(t *testing.T) {
operator: sqlparser.NotEqualStr,
col: textCol,
value: "foo",
out: elastic.NewBoolQuery().MustNot(elastic.NewMatchQuery(textCol.FieldName, "foo")),
out: newBoolQuery().MustNot(elastic.NewMatchQuery(textCol.FieldName, "foo")),
},
{
name: "invalid operator",
Expand Down Expand Up @@ -611,7 +662,7 @@ func TestQueryConverter_ConvertRangeExpr(t *testing.T) {
col: keywordCol,
from: "123",
to: "456",
out: elastic.NewBoolQuery().MustNot(
out: newBoolQuery().MustNot(
elastic.NewRangeQuery(keywordCol.FieldName).Gte("123").Lte("456"),
),
},
Expand Down Expand Up @@ -663,7 +714,7 @@ func TestQueryConverter_ConvertIsExpr(t *testing.T) {
name: "operator is null",
operator: sqlparser.IsNullStr,
col: keywordCol,
out: elastic.NewBoolQuery().MustNot(elastic.NewExistsQuery(keywordCol.FieldName)),
out: newBoolQuery().MustNot(elastic.NewExistsQuery(keywordCol.FieldName)),
},
{
name: "operator is not null",
Expand Down
Loading
Loading