Skip to content

Commit a38f9a6

Browse files
committed
Fixed generation of discriminator comparing expressions to be translatable to database queries
1 parent 7a43356 commit a38f9a6

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs

+24-16
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ open System.Linq.Expressions
2727
open System.Runtime.InteropServices
2828
open System.Reflection
2929

30+
type private CompareDiscriminatorExpression<'T, 'D> = Expression<Func<'T, 'D, bool>>
31+
3032
/// <summary>
3133
/// Allows to specify discriminator comparison or discriminator getter
3234
/// and a function that return discriminator value depending on entity type
@@ -59,7 +61,7 @@ open System.Reflection
5961
/// </code></example>
6062
[<Struct>]
6163
type ObjectListFilterLinqOptions<'T, 'D>
62-
([<Optional>] compareDiscriminator : Expression<Func<'T, 'D, bool>> | null, [<Optional>] getDiscriminatorValue : (Type -> 'D) | null) =
64+
([<Optional>] compareDiscriminator : CompareDiscriminatorExpression<'T, 'D> | null, [<Optional>] getDiscriminatorValue : (Type -> 'D) | null) =
6365

6466
member _.CompareDiscriminator = compareDiscriminator |> ValueOption.ofObj
6567
member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj
@@ -74,7 +76,7 @@ type ObjectListFilterLinqOptions<'T, 'D>
7476

7577
new (getDiscriminator : Expression<Func<'T, 'D>>) =
7678
ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null)
77-
new (compareDiscriminator : Expression<Func<'T, 'D, bool>>) = ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator, null)
79+
new (compareDiscriminator : CompareDiscriminatorExpression<'T, 'D>) = ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator, null)
7880
new (getDiscriminatorValue : Type -> 'D) =
7981
ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator = null, getDiscriminatorValue = getDiscriminatorValue)
8082
new (getDiscriminator : Expression<Func<'T, 'D>>, getDiscriminatorValue : Type -> 'D) =
@@ -227,6 +229,20 @@ module ObjectListFilter =
227229
let paramExpr = Expression.PropertyOrField (param, f.FieldName)
228230
buildFilterExpr (SourceExpression paramExpr) buildTypeDiscriminatorCheck f.Value
229231

232+
type private CompareDiscriminatorExpressionVisitor<'T, 'D> (
233+
compareDiscriminator : CompareDiscriminatorExpression<'T, 'D>,
234+
param : SourceExpression,
235+
value : obj
236+
) =
237+
inherit ExpressionVisitor ()
238+
override _.VisitParameter(node) =
239+
if node = compareDiscriminator.Parameters.[0] then
240+
param.Value
241+
elif node = compareDiscriminator.Parameters.[1] then
242+
Expression.Constant(value) :> Expression
243+
else
244+
node :> Expression
245+
230246
let apply (options : ObjectListFilterLinqOptions<'T, 'D>) (filter : ObjectListFilter) (query : IQueryable<'T>) =
231247
// Helper for discriminator comparison
232248
let buildTypeDiscriminatorCheck (param : SourceExpression) (t : Type) =
@@ -239,13 +255,9 @@ module ObjectListFilter =
239255
Expression.Constant (t.FullName)
240256
) :> Expression
241257
| ValueSome discExpr, ValueNone ->
242-
Expression.Invoke (
243-
// Provided discriminator comparison
244-
discExpr,
245-
param,
246-
// Default discriminator value gathered from type
247-
Expression.Constant(t.FullName)
248-
) :> Expression
258+
// Replace parameters from the original expression with our new ones
259+
let replacer = CompareDiscriminatorExpressionVisitor (discExpr, param, t.FullName)
260+
replacer.Visit discExpr.Body
249261
| ValueNone, ValueSome discValueFn ->
250262
let discriminatorValue = discValueFn t
251263
Expression.Equal (
@@ -256,13 +268,9 @@ module ObjectListFilter =
256268
) :> Expression
257269
| ValueSome discExpr, ValueSome discValueFn ->
258270
let discriminatorValue = discValueFn t
259-
Expression.Invoke (
260-
// Provided discriminator comparison
261-
discExpr,
262-
param,
263-
// Provided discriminator value gathered from type
264-
Expression.Constant (discriminatorValue)
265-
)
271+
// Replace parameters from the original expression with our new ones
272+
let replacer = CompareDiscriminatorExpressionVisitor (discExpr, param, discriminatorValue)
273+
replacer.Visit discExpr.Body
266274
let queryExpr =
267275
let param = Expression.Parameter (typeof<'T>, "x")
268276
let body = buildFilterExpr (SourceExpression param) buildTypeDiscriminatorCheck filter

0 commit comments

Comments
 (0)