Skip to content

Commit 616bd62

Browse files
committed
Support raw SQL projections in Select and OrderBy statements.
1 parent e03ecc8 commit 616bd62

File tree

6 files changed

+238
-14
lines changed

6 files changed

+238
-14
lines changed

Diff for: src/Marten.Testing/CoreFunctionality/query_by_sql.cs

+45
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Linq;
33
using System.Threading.Tasks;
44
using Marten.Linq.MatchesSql;
5+
using Marten.Linq.SqlProjection;
56
using Marten.Testing.Documents;
67
using Marten.Testing.Harness;
78
using Shouldly;
@@ -324,6 +325,50 @@ public async Task query_with_select_in_query_async()
324325
}
325326
}
326327

328+
[Fact]
329+
public async Task query_with_select_sql_projection_async()
330+
{
331+
using (var session = theStore.OpenSession())
332+
{
333+
var u = new User {FirstName = "Jeremy", LastName = "Miller", Age = 1337};
334+
session.Store(u);
335+
session.SaveChanges();
336+
337+
#region sample_using-sql-projection-queryasync
338+
339+
var users = await session.Query<User>()
340+
.Select(x => new { Age = x.SqlProjection<int>("data->>'Age'") })
341+
.ToListAsync();
342+
var user = users.Single();
343+
344+
#endregion
345+
346+
user.Age.ShouldBe(1337);
347+
}
348+
}
349+
350+
[Fact]
351+
public async Task query_with_order_by_sql_projection_async()
352+
{
353+
using (var session = theStore.OpenSession())
354+
{
355+
var u = new User {FirstName = "Jeremy", LastName = "Miller"};
356+
session.Store(u);
357+
session.SaveChanges();
358+
359+
#region sample_using-sql-projection-queryasync
360+
361+
var users = await session.Query<User>()
362+
.OrderBy(x => x.SqlProjection<string>("data->>'FirstName'"))
363+
.ToListAsync();
364+
var user = users.Single();
365+
366+
#endregion
367+
368+
user.FirstName.ShouldBe("Jeremy");
369+
}
370+
}
371+
327372
[Fact]
328373
public async Task get_sum_of_integers_asynchronously()
329374
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System;
2+
using Marten.Linq.SqlProjection;
3+
using Shouldly;
4+
using Xunit;
5+
6+
namespace Marten.Testing.Linq.SqlProjection
7+
{
8+
public class SqlProjectionTests
9+
{
10+
[Fact]
11+
public void Throws_NotSupportedException_when_called_directly()
12+
{
13+
Should.Throw<NotSupportedException>(
14+
() => new object().SqlProjection<string>("COALESCE(d.data ->> 'UserName', ?)", "baz"));
15+
}
16+
}
17+
}

Diff for: src/Marten/Linq/Parsing/SelectTransformBuilder.cs

+89-11
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
using System.Reflection;
77
using Baseline;
88
using Marten.Linq.Fields;
9+
using Marten.Linq.SqlProjection;
910
using Remotion.Linq.Parsing;
11+
using Weasel.Postgresql.SqlGeneration;
1012

1113
namespace Marten.Linq.Parsing
1214
{
1315
internal class SelectTransformBuilder : RelinqExpressionVisitor
1416
{
1517
private TargetObject _target;
16-
private SelectedField _currentField;
18+
private BindingTarget _currentTarget;
1719

1820
public SelectTransformBuilder(Expression clause, IFieldMapping fields, ISerializer serializer)
1921
{
@@ -35,7 +37,7 @@ protected override Expression VisitNew(NewExpression expression)
3537

3638
for (var i = 0; i < parameters.Length; i++)
3739
{
38-
_currentField = _target.StartBinding(parameters[i].Name);
40+
_currentTarget = _target.StartBinding(parameters[i].Name);
3941
Visit(expression.Arguments[i]);
4042
}
4143

@@ -44,21 +46,76 @@ protected override Expression VisitNew(NewExpression expression)
4446

4547
protected override Expression VisitMember(MemberExpression node)
4648
{
47-
_currentField.Add(node.Member);
49+
_currentTarget.AddMember(node.Member);
4850
return base.VisitMember(node);
4951
}
5052

5153
protected override MemberBinding VisitMemberBinding(MemberBinding node)
5254
{
53-
_currentField = _target.StartBinding(node.Member.Name);
55+
_currentTarget = _target.StartBinding(node.Member.Name);
5456

5557
return base.VisitMemberBinding(node);
5658
}
5759

60+
protected override Expression VisitMethodCall(MethodCallExpression node)
61+
{
62+
var fragment = SqlProjectionSqlFragment.TryParse(node);
63+
if (fragment == null)
64+
{
65+
throw new NotSupportedException(
66+
$"Method {node.Method.DeclaringType?.FullName}.{node.Method.Name} is not supported.");
67+
}
68+
69+
_currentTarget.AddSqlProjection(fragment);
70+
71+
return base.VisitMethodCall(node);
72+
}
73+
74+
public class BindingTarget : TargetObject.ISetterBinding
75+
{
76+
private readonly string _name;
77+
private TargetObject.SetterBinding _field;
78+
private TargetObject.SqlProjectionBinding _sqlProjection;
79+
80+
public BindingTarget(string name)
81+
{
82+
_name = name;
83+
}
84+
85+
public void AddMember(MemberInfo memberInfo)
86+
{
87+
if (_sqlProjection != null)
88+
{
89+
throw new InvalidOperationException(
90+
"Cannot bind to a member after having bound to a sql projection");
91+
}
92+
93+
_field ??= new TargetObject.SetterBinding(_name);
94+
_field.Field.Add(memberInfo);
95+
}
96+
97+
public void AddSqlProjection(ISqlFragment sqlProjectionClause)
98+
{
99+
if (_field != null)
100+
{
101+
throw new InvalidOperationException(
102+
"Cannot bind to a sql projection after having bound to a member.");
103+
}
104+
105+
_sqlProjection = new TargetObject.SqlProjectionBinding(_name, sqlProjectionClause);
106+
}
107+
108+
public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer)
109+
{
110+
return _field?.ToJsonBuildObjectPair(mapping, serializer)
111+
?? _sqlProjection?.ToJsonBuildObjectPair(mapping, serializer)
112+
?? string.Empty;
113+
}
114+
}
58115

59116
public class TargetObject
60117
{
61-
private readonly IList<SetterBinding> _setters = new List<SetterBinding>();
118+
private readonly IList<ISetterBinding> _setters = new List<ISetterBinding>();
62119

63120
public TargetObject(Type type)
64121
{
@@ -67,12 +124,11 @@ public TargetObject(Type type)
67124

68125
public Type Type { get; }
69126

70-
public SelectedField StartBinding(string bindingName)
127+
public BindingTarget StartBinding(string bindingName)
71128
{
72-
var setter = new SetterBinding(bindingName);
73-
_setters.Add(setter);
74-
75-
return setter.Field;
129+
var bindingTarget = new BindingTarget(bindingName);
130+
_setters.Add(bindingTarget);
131+
return bindingTarget;
76132
}
77133

78134
public string ToSelectField(IFieldMapping fields, ISerializer serializer)
@@ -81,7 +137,12 @@ public string ToSelectField(IFieldMapping fields, ISerializer serializer)
81137
return $"jsonb_build_object({jsonBuildObjectArgs})";
82138
}
83139

84-
private class SetterBinding
140+
public interface ISetterBinding
141+
{
142+
string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer);
143+
}
144+
145+
public class SetterBinding: ISetterBinding
85146
{
86147
public SetterBinding(string name)
87148
{
@@ -101,6 +162,23 @@ public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serialize
101162
return $"'{Name}', {locator}";
102163
}
103164
}
165+
166+
public class SqlProjectionBinding: ISetterBinding
167+
{
168+
public SqlProjectionBinding(string name, ISqlFragment projectionFragment)
169+
{
170+
Name = name;
171+
ProjectionFragment = projectionFragment;
172+
}
173+
174+
private string Name { get; }
175+
private ISqlFragment ProjectionFragment { get; }
176+
177+
public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer)
178+
{
179+
return $"'{Name}', ({ProjectionFragment.ToSql()})";
180+
}
181+
}
104182
}
105183

106184
public class SelectedField: IEnumerable<MemberInfo>

Diff for: src/Marten/Linq/SqlGeneration/Statement.cs

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
using System;
12
using System.Collections.Generic;
23
using System.Linq;
4+
using System.Linq.Expressions;
35
using Baseline;
46
using Marten.Internal;
57
using Marten.Linq.Fields;
68
using Marten.Linq.Parsing;
9+
using Marten.Linq.SqlProjection;
710
using Weasel.Postgresql;
811
using Npgsql;
912
using Remotion.Linq.Clauses;
@@ -95,9 +98,23 @@ protected void writeWhereClause(CommandBuilder sql)
9598

9699
protected void writeOrderByFragment(CommandBuilder sql, Ordering clause)
97100
{
98-
var field = Fields.FieldFor(clause.Expression);
99-
var locator = field.ToOrderExpression(clause.Expression);
100-
sql.Append(locator);
101+
if (clause.Expression is MethodCallExpression methodCallExpression)
102+
{
103+
var sqlProjectionFragment = SqlProjectionSqlFragment.TryParse(methodCallExpression);
104+
if (sqlProjectionFragment == null)
105+
{
106+
throw new NotSupportedException(
107+
$"Method {methodCallExpression.Method.DeclaringType?.FullName}.{methodCallExpression.Method.Name} is not supported.");
108+
}
109+
110+
sqlProjectionFragment.Apply(sql);
111+
}
112+
else
113+
{
114+
var field = Fields.FieldFor(clause.Expression);
115+
var locator = field.ToOrderExpression(clause.Expression);
116+
sql.Append(locator);
117+
}
101118

102119
if (clause.OrderingDirection == OrderingDirection.Desc) sql.Append(" desc");
103120
}
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Reflection;
3+
4+
namespace Marten.Linq.SqlProjection
5+
{
6+
public static class SqlProjectionExtensions
7+
{
8+
public static readonly MethodInfo MethodInfo = typeof(SqlProjectionExtensions)
9+
.GetMethod(nameof(SqlProjection),
10+
BindingFlags.Public | BindingFlags.Static);
11+
12+
public static T SqlProjection<T>(this object doc, string sql, params object[] parameters)
13+
{
14+
throw new NotSupportedException(
15+
$"{nameof(SqlProjection)} extension method can only be used in Marten Linq queries.");
16+
}
17+
}
18+
}
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using System;
2+
using System.Linq.Expressions;
3+
using Weasel.Postgresql;
4+
using Weasel.Postgresql.SqlGeneration;
5+
6+
namespace Marten.Linq.SqlProjection
7+
{
8+
public class SqlProjectionSqlFragment : ISqlFragment
9+
{
10+
public void Apply(CommandBuilder builder)
11+
{
12+
throw new System.NotImplementedException();
13+
}
14+
15+
public bool Contains(string sqlText)
16+
{
17+
throw new System.NotImplementedException();
18+
}
19+
20+
public static ISqlFragment TryParse(MethodCallExpression node, Func<Expression, Expression> visit = null)
21+
{
22+
if (node == null)
23+
{
24+
return null;
25+
}
26+
27+
visit ??= x => x;
28+
29+
if (!node.Method.IsGenericMethod ||
30+
node.Method.GetGenericMethodDefinition() != SqlProjectionExtensions.MethodInfo)
31+
{
32+
return null;
33+
}
34+
35+
if (visit(node.Arguments[1]) is not ConstantExpression { Value: string sql })
36+
{
37+
throw new NotSupportedException("SqlProjection first parameter needs to resolve to a string");
38+
}
39+
40+
if (visit(node.Arguments[2]) is not ConstantExpression { Value: object[] sqlArguments })
41+
{
42+
throw new NotSupportedException("SqlProjection second parameter needs to resolve to an object[]");
43+
}
44+
45+
var whereFragment = new WhereFragment(sql, sqlArguments);
46+
return whereFragment;
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)