Skip to content

Commit

Permalink
Support OfType/is in in-memory database
Browse files Browse the repository at this point in the history
Part of #16963
  • Loading branch information
ajcvickers committed Aug 21, 2019
1 parent 442d3bc commit 71b35f2
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ private Expression BindProperty(Expression source, string propertyName, Type typ
}
}

var result = BindProperty(entityProjection, entityType.FindProperty(propertyName));
var property = entityType.GetRootType().GetDerivedTypesInclusive()
.Select(et => et.FindProperty(propertyName))
.FirstOrDefault(p => p != null);

var result = BindProperty(entityProjection, property);
return result.Type == type
? result
: Expression.Convert(result, type);
Expand Down Expand Up @@ -231,29 +235,36 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
&& Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression)
{
var entityType = entityProjectionExpression.EntityType;

if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand))
{
return Expression.Constant(true);
}

//var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand);
//if (derivedType != null)
//{
// var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
// var discriminatorColumn = BindProperty(entityProjectionExpression, entityType.GetDiscriminatorProperty());

// return concreteEntityTypes.Count == 1
// ? _sqlExpressionFactory.Equal(discriminatorColumn,
// _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
// : (Expression)_sqlExpressionFactory.In(discriminatorColumn,
// _sqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()),
// negated: false);
//}

//return _sqlExpressionFactory.Constant(false);
var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand);
if (derivedType != null)
{
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var boundProperty = BindProperty(entityProjectionExpression, discriminatorProperty);

var equals = Expression.Equal(
boundProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
boundProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}

return equals;
}
}

return null;
return Expression.Constant(false);
}

protected override Expression VisitExtension(Expression extensionExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private static readonly MethodInfo _efPropertyMethod = typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property));

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;
private readonly InMemoryProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly IModel _model;
Expand Down Expand Up @@ -105,7 +108,8 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
return source;
}

protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2)
=> null;

protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item)
{
Expand Down Expand Up @@ -282,7 +286,66 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour
=> TranslateScalarAggregate(source, selector, nameof(Enumerable.Min));

protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType)
=> null;
{
if (source.ShaperExpression is EntityShaperExpression entityShaperExpression)
{
var entityType = entityShaperExpression.EntityType;
if (entityType.ClrType == resultType)
{
return source;
}

var baseType = entityType.GetAllBaseTypes().SingleOrDefault(et => et.ClrType == resultType);
if (baseType != null)
{
source.ShaperExpression = entityShaperExpression.WithEntityType(baseType);

return source;
}

var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == resultType);
if (derivedType != null)
{
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var parameter = Expression.Parameter(entityType.ClrType);

var callEFProperty = Expression.Call(
_efPropertyMethod.MakeGenericMethod(
discriminatorProperty.ClrType),
parameter,
Expression.Constant(discriminatorProperty.Name));

var equals = Expression.Equal(
callEFProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
callEFProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}

var predicate = Expression.Lambda(
equals,
parameter);

inMemoryQueryExpression.ServerQueryExpression = Expression.Call(
InMemoryLinqOperatorProvider.Where.MakeGenericMethod(typeof(ValueBuffer)),
inMemoryQueryExpression.ServerQueryExpression,
TranslateLambdaExpression(source, predicate));

source.ShaperExpression = entityShaperExpression.WithEntityType(derivedType);

return source;
}
}

return null;
}

protected override ShapedQueryExpression TranslateOrderBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,14 @@ public InheritanceInMemoryTest(InheritanceInMemoryFixture fixture, ITestOutputHe
//TestLoggerFactory.TestOutputHelper = testOutputHelper;
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_is_kiwi()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_is_kiwi_with_other_predicate()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Subquery_OfType()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Discriminator_used_when_projection_over_of_type()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_animal()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_bird()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_bird_first()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_bird_predicate()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_bird_with_projection()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_kiwi()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_kiwi_where_north_on_derived_property()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_kiwi_where_south_on_derived_property()
{
}

[ConditionalFact(Skip = "Issue #16963")]
public override void Can_use_of_type_rose()
{
}

[ConditionalFact(Skip = "See issue#13857")]
[ConditionalFact(Skip = "See issue#13857")] // Defining query
public override void Can_query_all_animal_views()
{
base.Can_query_all_animal_views();
}

protected override bool EnforcesFkConstraints => false;
Expand Down
25 changes: 25 additions & 0 deletions test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ public virtual void Can_use_is_kiwi()
}
}

[ConditionalFact]
public virtual void Can_use_backwards_is_animal()
{
using (var context = CreateContext())
{
// ReSharper disable once IsExpressionAlwaysTrue
var kiwis = context.Set<Kiwi>().Where(a => a is Animal).ToList();

Assert.Equal(1, kiwis.Count);
}
}

[ConditionalFact]
public virtual void Can_use_is_kiwi_with_other_predicate()
{
Expand Down Expand Up @@ -191,6 +203,19 @@ public virtual void Can_use_of_type_kiwi()
}
}

[ConditionalFact(Skip = "17364")]
public virtual void Can_use_backwards_of_type_animal()
{
using (var context = CreateContext())
{
var animals = context.Set<Kiwi>().OfType<Animal>().ToList();

Assert.Equal(1, animals.Count);
Assert.IsType<Kiwi>(animals[0]);
Assert.Equal(1, context.ChangeTracker.Entries().Count());
}
}

[ConditionalFact]
public virtual void Can_use_of_type_rose()
{
Expand Down

0 comments on commit 71b35f2

Please sign in to comment.