Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference Include for in-memory provider #17011

Merged
merged 1 commit into from
Aug 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Reference Include for in-memory provider
Part of #16963
  • Loading branch information
ajcvickers committed Aug 7, 2019
commit 06461a1840eb4b37573b15bd2cba2eb343a49305
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public partial class InMemoryShapedQueryCompilingExpressionVisitor
{
private class CustomShaperCompilingExpressionVisitor : ExpressionVisitor
{
private readonly bool _tracking;

public CustomShaperCompilingExpressionVisitor(bool tracking)
{
_tracking = tracking;
}

private static readonly MethodInfo _includeReferenceMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(IncludeReference));

private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>(
QueryContext queryContext,
TEntity entity,
TIncludedEntity relatedEntity,
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : TEntity
{
if (entity is TIncludingEntity includingEntity)
{
if (trackingQuery)
{
// For non-null relatedEntity StateManager will set the flag
if (relatedEntity == null)
{
queryContext.StateManager.TryGetEntry(includingEntity).SetIsLoaded(navigation);
}
}
else
{
SetIsLoadedNoTracking(includingEntity, navigation);
if (relatedEntity != null)
{
fixup(includingEntity, relatedEntity);
if (inverseNavigation != null
&& !inverseNavigation.IsCollection())
{
SetIsLoadedNoTracking(relatedEntity, inverseNavigation);
}
}
}
}
}

private static void SetIsLoadedNoTracking(object entity, INavigation navigation)
=> ((ILazyLoader)(navigation
.DeclaringEntityType
.GetServiceProperties()
.FirstOrDefault(p => p.ClrType == typeof(ILazyLoader)))
?.GetGetter().GetClrValue(entity))
?.SetLoaded(entity, navigation.Name);

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is IncludeExpression includeExpression)
{
Debug.Assert(
!includeExpression.Navigation.IsCollection(),
"Only reference include should be present in tree");

var entityClrType = includeExpression.EntityExpression.Type;
var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType;
var inverseNavigation = includeExpression.Navigation.FindInverse();
var relatedEntityClrType = includeExpression.Navigation.GetTargetType().ClrType;
if (includingClrType != entityClrType
&& includingClrType.IsAssignableFrom(entityClrType))
{
includingClrType = entityClrType;
}

return Expression.Call(
_includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType),
QueryCompilationContext.QueryContextParameter,
// We don't need to visit entityExpression since it is supposed to be a parameterExpression only
includeExpression.EntityExpression,
includeExpression.NavigationExpression,
Expression.Constant(includeExpression.Navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(
GenerateFixup(
includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()),
Expression.Constant(_tracking));
}

return base.VisitExtension(extensionExpression);
}

private static LambdaExpression GenerateFixup(
Type entityType,
Type relatedEntityType,
INavigation navigation,
INavigation inverseNavigation)
{
var entityParameter = Expression.Parameter(entityType);
var relatedEntityParameter = Expression.Parameter(relatedEntityType);
var expressions = new List<Expression>
{
navigation.IsCollection()
? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation)
: AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation)
};

if (inverseNavigation != null)
{
expressions.Add(
inverseNavigation.IsCollection()
? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation)
: AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation));

}

return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter);
}

private static Expression AssignReferenceNavigation(
ParameterExpression entity,
ParameterExpression relatedEntity,
INavigation navigation)
{
return entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity);
}

private static Expression AddToCollectionNavigation(
ParameterExpression entity,
ParameterExpression relatedEntity,
INavigation navigation)
=> Expression.Call(
Expression.Constant(navigation.GetCollectionAccessor()),
_collectionAccessorAddMethodInfo,
entity,
relatedEntity,
Expression.Constant(true));

private static readonly MethodInfo _collectionAccessorAddMethodInfo
= typeof(IClrCollectionAccessor).GetTypeInfo()
.GetDeclaredMethod(nameof(IClrCollectionAccessor.Add));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ protected override Expression VisitExtension(Expression extensionExpression)
}
}

if (extensionExpression is IncludeExpression includeExpression)
{
return _clientEval
? base.VisitExtension(includeExpression)
: null;
}

throw new InvalidOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor
public partial class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor
{
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private static readonly ConstructorInfo _valueBufferConstructor
= typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1);

public InMemoryShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
Expand Down Expand Up @@ -53,19 +51,20 @@ protected override Expression VisitExtension(Expression extensionExpression)

protected override Expression VisitShapedQueryExpression(ShapedQueryExpression shapedQueryExpression)
{
var shaperBody = InjectEntityMaterializers(shapedQueryExpression.ShaperExpression);
var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression;

var innerEnumerable = Visit(shapedQueryExpression.QueryExpression);
var shaper = new ShaperExpressionProcessingExpressionVisitor(inMemoryQueryExpression)
.Inject(shapedQueryExpression.ShaperExpression);

var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression;
shaper = InjectEntityMaterializers(shaper);

var newBody = new InMemoryProjectionBindingRemovingExpressionVisitor(inMemoryQueryExpression)
.Visit(shaperBody);
var innerEnumerable = Visit(inMemoryQueryExpression);

var shaperLambda = Expression.Lambda(
newBody,
QueryCompilationContext.QueryContextParameter,
inMemoryQueryExpression.ValueBufferParameter);
shaper = new InMemoryProjectionBindingRemovingExpressionVisitor(inMemoryQueryExpression).Visit(shaper);

shaper = new CustomShaperCompilingExpressionVisitor(IsTracking).Visit(shaper);

var shaperLambda = (LambdaExpression)shaper;

return Expression.New(
(IsAsync
Expand Down Expand Up @@ -263,6 +262,7 @@ public ValueTask DisposeAsync()
private class InMemoryProjectionBindingRemovingExpressionVisitor : ExpressionVisitor
{
private readonly InMemoryQueryExpression _queryExpression;

private readonly IDictionary<ParameterExpression, IDictionary<IProperty, int>> _materializationContextBindings
= new Dictionary<ParameterExpression, IDictionary<IProperty, int>>();

Expand All @@ -284,7 +284,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
_materializationContextBindings[parameterExpression]
= (IDictionary<IProperty, int>)GetProjectionIndex(projectionBindingExpression);

var updatedExpression = Expression.New(newExpression.Constructor,
var updatedExpression = Expression.New(
newExpression.Constructor,
Expression.Constant(ValueBuffer.Empty),
newExpression.Arguments[1]);

Expand All @@ -300,7 +301,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod)
{
var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value;
var indexMap = _materializationContextBindings[(ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object];
var indexMap =
_materializationContextBindings[
(ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object];

return Expression.Call(
methodCallExpression.Method,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor
{
private readonly InMemoryQueryExpression _queryExpression;

private readonly IDictionary<Expression, ParameterExpression> _mapping = new Dictionary<Expression, ParameterExpression>();
private readonly List<ParameterExpression> _variables = new List<ParameterExpression>();
private readonly List<Expression> _expressions = new List<Expression>();

public ShaperExpressionProcessingExpressionVisitor(
InMemoryQueryExpression queryExpression)
{
_queryExpression = queryExpression;
}

public virtual Expression Inject(Expression expression)
{
var result = Visit(expression);

if (_expressions.All(e => e.NodeType == ExpressionType.Assign))
{
result = new ReplacingExpressionVisitor(_expressions.Cast<BinaryExpression>()
.ToDictionary(e => e.Left, e => e.Right)).Visit(result);
}
else
{
_expressions.Add(result);
result = Expression.Block(_variables, _expressions);
}

return ConvertToLambda(result, Expression.Parameter(result.Type, "result"));
}

private LambdaExpression ConvertToLambda(Expression result, ParameterExpression resultParameter)
=> Expression.Lambda(
result,
QueryCompilationContext.QueryContextParameter,
_queryExpression.ValueBufferParameter);

protected override Expression VisitExtension(Expression extensionExpression)
{
switch (extensionExpression)
{
case EntityShaperExpression entityShaperExpression:
{
var key = GenerateKey((ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression);
if (!_mapping.TryGetValue(key, out var variable))
{
variable = Expression.Parameter(entityShaperExpression.EntityType.ClrType);
_variables.Add(variable);
_expressions.Add(Expression.Assign(variable, entityShaperExpression));
_mapping[key] = variable;
}

return variable;
}

case ProjectionBindingExpression projectionBindingExpression:
{
var key = GenerateKey(projectionBindingExpression);
if (!_mapping.TryGetValue(key, out var variable))
{
variable = Expression.Parameter(projectionBindingExpression.Type);
_variables.Add(variable);
_expressions.Add(Expression.Assign(variable, projectionBindingExpression));
_mapping[key] = variable;
}

return variable;
}

case IncludeExpression includeExpression:
{
var entity = Visit(includeExpression.EntityExpression);
_expressions.Add(
includeExpression.Update(
entity,
Visit(includeExpression.NavigationExpression)));

return entity;
}
}

return base.VisitExtension(extensionExpression);
}

private Expression GenerateKey(ProjectionBindingExpression projectionBindingExpression)
=> projectionBindingExpression.ProjectionMember != null
? _queryExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: projectionBindingExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ public class InMemoryComplianceTest : ComplianceTestBase
typeof(ComplexNavigationsQueryTestBase<>),
typeof(GearsOfWarQueryTestBase<>),
typeof(IncludeAsyncTestBase<>),
typeof(IncludeOneToOneTestBase<>),
typeof(IncludeTestBase<>),
typeof(InheritanceRelationshipsQueryTestBase<>),
typeof(InheritanceTestBase<>),
typeof(NullKeysTestBase<>),
Expand Down
Loading