@@ -27,6 +27,8 @@ open System.Linq.Expressions
27
27
open System.Runtime .InteropServices
28
28
open System.Reflection
29
29
30
+ type private CompareDiscriminatorExpression < 'T , 'D > = Expression< Func< 'T, 'D, bool>>
31
+
30
32
/// <summary>
31
33
/// Allows to specify discriminator comparison or discriminator getter
32
34
/// and a function that return discriminator value depending on entity type
@@ -59,7 +61,7 @@ open System.Reflection
59
61
/// </code></example>
60
62
[<Struct>]
61
63
type ObjectListFilterLinqOptions < 'T , 'D >
62
- ([< Optional>] compareDiscriminator : Expression < Func < 'T, 'D, bool > > | null , [< Optional>] getDiscriminatorValue : ( Type -> 'D) | null ) =
64
+ ([< Optional>] compareDiscriminator : CompareDiscriminatorExpression < 'T, 'D> | null , [< Optional>] getDiscriminatorValue : ( Type -> 'D) | null ) =
63
65
64
66
member _.CompareDiscriminator = compareDiscriminator |> ValueOption.ofObj
65
67
member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj
@@ -74,7 +76,7 @@ type ObjectListFilterLinqOptions<'T, 'D>
74
76
75
77
new ( getDiscriminator : Expression < Func < 'T , 'D >>) =
76
78
ObjectListFilterLinqOptions< 'T, 'D> ( ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null )
77
- new ( compareDiscriminator : Expression < Func < 'T , 'D , bool > >) = ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator, null )
79
+ new ( compareDiscriminator : CompareDiscriminatorExpression < 'T , 'D >) = ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator, null )
78
80
new ( getDiscriminatorValue : Type -> 'D ) =
79
81
ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator = null , getDiscriminatorValue = getDiscriminatorValue)
80
82
new ( getDiscriminator : Expression < Func < 'T , 'D >>, getDiscriminatorValue : Type -> 'D ) =
@@ -227,6 +229,20 @@ module ObjectListFilter =
227
229
let paramExpr = Expression.PropertyOrField ( param, f.FieldName)
228
230
buildFilterExpr ( SourceExpression paramExpr) buildTypeDiscriminatorCheck f.Value
229
231
232
+ type private CompareDiscriminatorExpressionVisitor < 'T , 'D > (
233
+ compareDiscriminator : CompareDiscriminatorExpression< 'T, 'D>,
234
+ param : SourceExpression,
235
+ value : obj
236
+ ) =
237
+ inherit ExpressionVisitor ()
238
+ override _.VisitParameter ( node ) =
239
+ if node = compareDiscriminator.Parameters.[ 0 ] then
240
+ param.Value
241
+ elif node = compareDiscriminator.Parameters.[ 1 ] then
242
+ Expression.Constant( value) :> Expression
243
+ else
244
+ node :> Expression
245
+
230
246
let apply ( options : ObjectListFilterLinqOptions < 'T , 'D >) ( filter : ObjectListFilter ) ( query : IQueryable < 'T >) =
231
247
// Helper for discriminator comparison
232
248
let buildTypeDiscriminatorCheck ( param : SourceExpression ) ( t : Type ) =
@@ -239,13 +255,9 @@ module ObjectListFilter =
239
255
Expression.Constant ( t.FullName)
240
256
) :> Expression
241
257
| ValueSome discExpr, ValueNone ->
242
- Expression.Invoke (
243
- // Provided discriminator comparison
244
- discExpr,
245
- param,
246
- // Default discriminator value gathered from type
247
- Expression.Constant( t.FullName)
248
- ) :> Expression
258
+ // Replace parameters from the original expression with our new ones
259
+ let replacer = CompareDiscriminatorExpressionVisitor ( discExpr, param, t.FullName)
260
+ replacer.Visit discExpr.Body
249
261
| ValueNone, ValueSome discValueFn ->
250
262
let discriminatorValue = discValueFn t
251
263
Expression.Equal (
@@ -256,13 +268,9 @@ module ObjectListFilter =
256
268
) :> Expression
257
269
| ValueSome discExpr, ValueSome discValueFn ->
258
270
let discriminatorValue = discValueFn t
259
- Expression.Invoke (
260
- // Provided discriminator comparison
261
- discExpr,
262
- param,
263
- // Provided discriminator value gathered from type
264
- Expression.Constant ( discriminatorValue)
265
- )
271
+ // Replace parameters from the original expression with our new ones
272
+ let replacer = CompareDiscriminatorExpressionVisitor ( discExpr, param, discriminatorValue)
273
+ replacer.Visit discExpr.Body
266
274
let queryExpr =
267
275
let param = Expression.Parameter ( typeof< 'T>, " x" )
268
276
let body = buildFilterExpr ( SourceExpression param) buildTypeDiscriminatorCheck filter
0 commit comments