diff --git a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs index 7905218..cd28cfa 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs @@ -1,168 +1,169 @@ using System.Reflection; -namespace EntityFrameworkCore.Projectables.Extensions +namespace EntityFrameworkCore.Projectables.Extensions; + +public static class TypeExtensions { - public static class TypeExtensions + public static Type[] GetNestedTypePath(this Type type) { - public static string GetSimplifiedTypeName(this Type type) + // First pass: count the nesting depth so we can size the array exactly. + var depth = 0; + var current = type; + while (true) { - var name = type.Name; - - var backtickIndex = name.IndexOf("`"); - if (backtickIndex != -1) + depth++; + if (!current.IsNested || current.DeclaringType is null) { - name = name.Substring(0, backtickIndex); + break; } - return name; + current = current.DeclaringType; } - public static IEnumerable GetNestedTypePath(this Type type) + // Second pass: fill the array outermost-first by walking back from the leaf. + var path = new Type[depth]; + current = type; + for (var i = depth - 1; i >= 0; i--) { - if (type.IsNested && type.DeclaringType is not null) - { - foreach (var containingType in type.DeclaringType.GetNestedTypePath()) - { - yield return containingType; - } - } - - yield return type; + path[i] = current; + current = current.DeclaringType!; } - private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo) - { - // We only need to search for virtual instance methods who are not declared on the derivedType - if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) - { - return false; - } + return path; + } - if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) - { - throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); - } + private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo) + { + // We only need to search for virtual instance methods who are not declared on the derivedType + if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) + { + return false; + } - return true; + if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) + { + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); } - private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition) - => methodInfo.GetBaseDefinition() == baseDefinition; + return true; + } + + private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition) + => methodInfo.GetBaseDefinition() == baseDefinition; - public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) + public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) + { + if (!derivedType.CanHaveOverridingMethod(methodInfo)) { - if (!derivedType.CanHaveOverridingMethod(methodInfo)) - { - return methodInfo; - } + return methodInfo; + } - var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - MethodInfo? overridingMethod = null; - if (derivedMethods is { Length: > 0 }) - { - var baseDefinition = methodInfo.GetBaseDefinition(); - overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo - => derivedMethodInfo.IsOverridingMethodOf(baseDefinition)); - } - - return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo + MethodInfo? overridingMethod = null; + if (derivedMethods is { Length: > 0 }) + { + var baseDefinition = methodInfo.GetBaseDefinition(); + overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo + => derivedMethodInfo.IsOverridingMethodOf(baseDefinition)); } - public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) + return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo + } + + private static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) + { + var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod); + if (accessor is null) { - var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod); - if (accessor is null) - { - return propertyInfo; - } + return propertyInfo; + } - var isGetAccessor = propertyInfo.GetMethod == accessor; + var isGetAccessor = propertyInfo.GetMethod == accessor; - var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - - PropertyInfo? overridingProperty = null; - if (derivedProperties is { Length: > 0 }) - { - var baseDefinition = accessor.GetBaseDefinition(); - overridingProperty = derivedProperties.FirstOrDefault(p - => (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true); - } - - return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo - } + var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) + PropertyInfo? overridingProperty = null; + if (derivedProperties is { Length: > 0 }) { - var interfaceType = methodInfo.DeclaringType; - // We only need to search for interface methods - if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual) - { - return methodInfo; - } - - if (!derivedType.IsAssignableTo(interfaceType)) - { - throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); - } + var baseDefinition = accessor.GetBaseDefinition(); + overridingProperty = derivedProperties.FirstOrDefault(p + => (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true); + } - var interfaceMap = derivedType.GetInterfaceMap(interfaceType); - for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++) - { - if (interfaceMap.InterfaceMethods[i] == methodInfo) - { - return interfaceMap.TargetMethods[i]; - } - } + return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo + } - throw new ApplicationException( - $"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!"); + private static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) + { + var interfaceType = methodInfo.DeclaringType; + // We only need to search for interface methods + if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual) + { + return methodInfo; } - public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo) + if (!derivedType.IsAssignableTo(interfaceType)) { - var accessor = propertyInfo.GetAccessors()[0]; + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); + } - var implementingAccessor = derivedType.GetImplementingMethod(accessor); - if (implementingAccessor == accessor) + var interfaceMap = derivedType.GetInterfaceMap(interfaceType); + for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++) + { + if (interfaceMap.InterfaceMethods[i] == methodInfo) { - return propertyInfo; + return interfaceMap.TargetMethods[i]; } + } - var implementingType = implementingAccessor.DeclaringType - // This should only be null if it is a property accessor on the global module, - // which should never happen since we found it from derivedType - ?? throw new ApplicationException("The property accessor has no declaring type!"); + throw new ApplicationException( + $"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!"); + } - var derivedProperties = implementingType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo) + { + var accessor = propertyInfo.GetAccessors()[0]; - return derivedProperties.FirstOrDefault(propertyInfo.GetMethod == accessor - ? p => MethodInfosEqual(p.GetMethod, implementingAccessor) - : p => MethodInfosEqual(p.SetMethod, implementingAccessor)) ?? propertyInfo; + var implementingAccessor = derivedType.GetImplementingMethod(accessor); + if (implementingAccessor == accessor) + { + return propertyInfo; } - /// - /// The built-in - /// does not work if the s don't agree. - /// - private static bool MethodInfosEqual(MethodInfo? first, MethodInfo second) - => first?.ReflectedType == second.ReflectedType - ? first == second - : first is not null - && first.DeclaringType == second.DeclaringType - && first.Name == second.Name - && first.GetParameters().Select(p => p.ParameterType) - .SequenceEqual(second.GetParameters().Select(p => p.ParameterType)) - && first.GetGenericArguments().SequenceEqual(second.GetGenericArguments()); - - public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo) - => methodInfo.DeclaringType?.IsInterface == true - ? derivedType.GetImplementingMethod(methodInfo) - : derivedType.GetOverridingMethod(methodInfo); - - public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo) - => propertyInfo.DeclaringType?.IsInterface == true - ? derivedType.GetImplementingProperty(propertyInfo) - : derivedType.GetOverridingProperty(propertyInfo); + var implementingType = implementingAccessor.DeclaringType + // This should only be null if it is a property accessor on the global module, + // which should never happen since we found it from derivedType + ?? throw new ApplicationException("The property accessor has no declaring type!"); + + var derivedProperties = implementingType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + return derivedProperties.FirstOrDefault(propertyInfo.GetMethod == accessor + ? p => MethodInfosEqual(p.GetMethod, implementingAccessor) + : p => MethodInfosEqual(p.SetMethod, implementingAccessor)) ?? propertyInfo; } + + /// + /// The built-in + /// does not work if the s don't agree. + /// + private static bool MethodInfosEqual(MethodInfo? first, MethodInfo second) + => first?.ReflectedType == second.ReflectedType + ? first == second + : first is not null + && first.DeclaringType == second.DeclaringType + && first.Name == second.Name + && first.GetParameters().Select(p => p.ParameterType) + .SequenceEqual(second.GetParameters().Select(p => p.ParameterType)) + && first.GetGenericArguments().SequenceEqual(second.GetGenericArguments()); + + public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo) + => methodInfo.DeclaringType?.IsInterface == true + ? derivedType.GetImplementingMethod(methodInfo) + : derivedType.GetOverridingMethod(methodInfo); + + public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo) + => propertyInfo.DeclaringType?.IsInterface == true + ? derivedType.GetImplementingProperty(propertyInfo) + : derivedType.GetOverridingProperty(propertyInfo); } diff --git a/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionResolver.cs index 64f73c7..087bad3 100644 --- a/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionResolver.cs @@ -1,10 +1,10 @@ using System.Linq.Expressions; using System.Reflection; -namespace EntityFrameworkCore.Projectables.Services +namespace EntityFrameworkCore.Projectables.Services; + +public interface IProjectionExpressionResolver { - public interface IProjectionExpressionResolver - { - LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo); - } + LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null); } \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 3bff25a..72bc0f8 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -51,7 +51,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La var projectableAttribute = memberInfo.GetCustomAttribute(false); reflectedExpression = projectableAttribute is not null - ? _resolver.FindGeneratedExpression(memberInfo) + ? _resolver.FindGeneratedExpression(memberInfo, projectableAttribute) : null; _projectableMemberCache.Add(memberInfo, reflectedExpression); diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionClassNameGenerator.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionClassNameGenerator.cs index d21c8ce..ee30ce6 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionClassNameGenerator.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionClassNameGenerator.cs @@ -37,7 +37,15 @@ public static string GenerateFullName(string? namespaceName, IEnumerable static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceName, IEnumerable? nestedInClassNames, string memberName, IEnumerable? parameterTypeNames) { - stringBuilder.Append(namespaceName?.Replace('.', '_')); + // Append namespace, replacing '.' separators with '_' in a single pass (no intermediate string). + if (namespaceName is not null) + { + foreach (var c in namespaceName) + { + stringBuilder.Append(c == '.' ? '_' : c); + } + } + stringBuilder.Append('_'); var arity = 0; @@ -65,7 +73,16 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam } } - stringBuilder.Append(memberName.Replace(".", "__")); // Support explicit interface implementations + + // Append member name; only allocate a replacement string for the rare explicit-interface case. + if (memberName.IndexOf('.') >= 0) + { + stringBuilder.Append(memberName.Replace(".", "__")); + } + else + { + stringBuilder.Append(memberName); + } // Add parameter types to make method overloads unique if (parameterTypeNames is not null) @@ -76,20 +93,9 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam stringBuilder.Append("_P"); stringBuilder.Append(parameterIndex); stringBuilder.Append('_'); - // Replace characters that are not valid in type names with underscores - var sanitizedTypeName = parameterTypeName - .Replace("global::", "") // Remove global:: prefix - .Replace('.', '_') - .Replace('<', '_') - .Replace('>', '_') - .Replace(',', '_') - .Replace(' ', '_') - .Replace('[', '_') - .Replace(']', '_') - .Replace('`', '_') - .Replace(':', '_') // Additional safety for any remaining colons - .Replace('?', '_'); // Handle nullable reference types - stringBuilder.Append(sanitizedTypeName); + // Single-pass sanitization: replace invalid identifier characters with '_', + // stripping the "global::" prefix on the fly — avoids 9 intermediate string allocations. + AppendSanitizedTypeName(stringBuilder, parameterTypeName); parameterIndex++; } } @@ -104,5 +110,26 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam return stringBuilder.ToString(); } + + /// + /// Appends to , stripping the + /// global:: prefix and replacing every character that is invalid in a C# identifier + /// with '_' — all in a single pass with no intermediate string allocations. + /// + private static void AppendSanitizedTypeName(StringBuilder sb, string typeName) + { + const string GlobalPrefix = "global::"; + var start = typeName.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? GlobalPrefix.Length : 0; + + for (var i = start; i < typeName.Length; i++) + { + var c = typeName[i]; + sb.Append(IsInvalidIdentifierChar(c) ? '_' : c); + } + } + + private static bool IsInvalidIdentifierChar(char c) => + c == '.' || c == '<' || c == '>' || c == ',' || c == ' ' || + c == '[' || c == ']' || c == '`' || c == ':' || c == '?'; } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 8fe4bc4..d00daad 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Concurrent; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using EntityFrameworkCore.Projectables.Extensions; namespace EntityFrameworkCore.Projectables.Services @@ -13,6 +15,42 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver private readonly static Func _nullRegistry = static _ => null!; private readonly static ConcurrentDictionary> _assemblyRegistries = new(); + /// + /// Caches the fully-resolved per so that + /// EF Core never repeats reflection work for the same member across queries. + /// + private readonly static ConcurrentDictionary _expressionCache = new(); + + /// + /// Caches → C#-formatted name strings, since the same parameter types + /// appear repeatedly across different projectable members. + /// + private readonly static ConditionalWeakTable _typeNameCache = new(); + + /// + /// O(1) hash-table lookup replacing the original 16 sequential if checks. + /// Rearranging the entries has no effect on lookup cost (hash-based), but the most common + /// EF Core types (int, string, bool) are listed first for readability. + /// + private readonly static Dictionary _csharpKeywords = new(16) + { + [typeof(int)] = "int", + [typeof(string)] = "string", + [typeof(bool)] = "bool", + [typeof(long)] = "long", + [typeof(double)] = "double", + [typeof(decimal)] = "decimal", + [typeof(float)] = "float", + [typeof(byte)] = "byte", + [typeof(sbyte)] = "sbyte", + [typeof(char)] = "char", + [typeof(uint)] = "uint", + [typeof(ulong)] = "ulong", + [typeof(short)] = "short", + [typeof(ushort)] = "ushort", + [typeof(object)] = "object", + }; + /// /// Looks up the generated ProjectionRegistry class in an assembly (once, then caches it). /// Returns a delegate that calls TryGet(MemberInfo) on the registry, or null if the registry @@ -38,9 +76,15 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver return ReferenceEquals(registry, _nullRegistry) ? null : registry; } - public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null) + => _expressionCache.GetOrAdd(projectableMemberInfo, static (mi, a) => ResolveExpressionCore(mi, a), + projectableAttribute); + + private static LambdaExpression ResolveExpressionCore(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null) { - var projectableAttribute = projectableMemberInfo.GetCustomAttribute() + projectableAttribute ??= projectableMemberInfo.GetCustomAttribute() ?? throw new InvalidOperationException("Expected member to have a Projectable attribute. None found"); var expression = GetExpressionFromGeneratedType(projectableMemberInfo); @@ -100,30 +144,28 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo case MethodInfo method: { var methodParams = method.GetParameters(); - + // The lambda's return type must match the method's return type. - if (lambda.ReturnType != method.ReturnType) + if (lambda.ReturnType != method.ReturnType) { return null; } - + if (method.IsStatic) { // Static methods (including extension methods): all parameters are explicit. - // The lambda maps directly to the method's parameter list — no implicit 'this'. if (lambda.Parameters.Count == methodParams.Length && - !lambda.Parameters.Zip(methodParams, (a, b) => a.Type != b.ParameterType).Any()) + ParameterTypesMatch(lambda.Parameters, 0, methodParams)) { return lambda; } } else { - // Instance methods: the lambda's first parameter is the implicit 'this' (the declaring - // type), followed by the explicit method parameters — i.e. (@this, arg1, arg2, ...) => ... + // Instance methods: lambda's first parameter is the implicit 'this'. if (lambda.Parameters.Count == methodParams.Length + 1 && lambda.Parameters[0].Type == declaringType && - !lambda.Parameters.Skip(1).Zip(methodParams, (a, b) => a.Type != b.ParameterType).Any()) + ParameterTypesMatch(lambda.Parameters, 1, methodParams)) { return lambda; } @@ -136,6 +178,41 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo return null; } + /// + /// Compares lambda parameter types against method parameter types starting at , + /// avoiding LINQ allocations (no Zip/Any enumerators or delegates). + /// + private static bool ParameterTypesMatch( + ReadOnlyCollection lambdaParams, + int offset, + ParameterInfo[] methodParams) + { + for (var i = 0; i < methodParams.Length; i++) + { + if (lambdaParams[offset + i].Type != methodParams[i].ParameterType) + { + return false; + } + } + + return true; + } + + /// + /// Sentinel stored in to represent + /// "no generated type found for this member", distinguishing it from a not-yet-populated entry. + /// does not allow null values, so a sentinel is required. + /// + private readonly static Func _reflectionNotFoundSentinel = static () => null!; + + /// + /// Caches a pre-compiled Func<LambdaExpression> delegate per + /// so that Assembly.GetType, GetMethod, MakeGenericType, and + /// MakeGenericMethod are only paid once per member. All subsequent calls execute + /// native JIT-compiled code with zero reflection overhead. + /// + private readonly static ConcurrentDictionary> _reflectionFactoryCache = new(); + /// /// Resolves the for a [Projectable] member using the /// reflection-based slow path only, bypassing the static registry. @@ -143,75 +220,134 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo /// public static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo) { - var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + var factory = _reflectionFactoryCache.GetOrAdd(projectableMemberInfo, static mi => BuildReflectionFactory(mi)); + return ReferenceEquals(factory, _reflectionNotFoundSentinel) ? null : factory.Invoke(); + } + + /// + /// Performs the one-time reflection work for a member and returns a compiled native delegate + /// (or if no generated type exists). + /// + /// We use Expression.Lambda<TDelegate>(...).Compile() rather than + /// Delegate.CreateDelegate because the generated Expression() factory method + /// returns Expression<TDelegate> (a subtype of ), and + /// CreateDelegate requires an exact return-type match in most runtime environments. + /// The expression-tree wrapper handles the covariant cast cleanly and compiles to native code. + /// + /// + private static Func BuildReflectionFactory(MemberInfo projectableMemberInfo) + { + var declaringType = projectableMemberInfo.DeclaringType + ?? throw new InvalidOperationException("Expected a valid type here"); - // Keep track of the original declaring type's generic arguments for later use var originalDeclaringType = declaringType; - // For generic types, use the generic type definition to match the generated name - // which is based on the open generic type + // For generic types, use the generic type definition to match the generated name. if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) { declaringType = declaringType.GetGenericTypeDefinition(); } - // Get parameter types for method overload disambiguation - // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat - // which uses C# keywords for primitive types (int, string, etc.) + // Build parameter type name array with a plain for-loop — avoids IEnumerator + delegate allocations. string[]? parameterTypeNames = null; - string memberLookupName = projectableMemberInfo.Name; + var memberLookupName = projectableMemberInfo.Name; + if (projectableMemberInfo is MethodInfo method) { - // For generic methods, use the generic definition to get parameter types - // This ensures type parameters like TEntity are used instead of concrete types + // For generic methods, use the generic definition so type parameters (TEntity, etc.) + // are used instead of the concrete closed-generic arguments. var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; + var parameters = methodToInspect.GetParameters(); - parameterTypeNames = methodToInspect.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); + if (parameters.Length > 0) + { + parameterTypeNames = new string[parameters.Length]; + for (var i = 0; i < parameters.Length; i++) + { + parameterTypeNames[i] = GetFullTypeName(parameters[i].ParameterType); + } + } } else if (projectableMemberInfo is ConstructorInfo ctor) { - // Constructors are stored under the synthetic name "_ctor" + // Constructors are stored under the synthetic name "_ctor". memberLookupName = "_ctor"; - parameterTypeNames = ctor.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); + var parameters = ctor.GetParameters(); + + if (parameters.Length > 0) + { + parameterTypeNames = new string[parameters.Length]; + for (var i = 0; i < parameters.Length; i++) + { + parameterTypeNames[i] = GetFullTypeName(parameters[i].ParameterType); + } + } } - var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); + // GetNestedTypePath() returns a Type[] — project to string[] with a direct loop, no LINQ Select. + var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName( + declaringType.Namespace, + NestedTypePathToNames(declaringType.GetNestedTypePath()), + memberLookupName, + parameterTypeNames); var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); - if (expressionFactoryType is not null) + if (expressionFactoryType is null) { - if (expressionFactoryType.IsGenericTypeDefinition) - { - expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); - } + return _reflectionNotFoundSentinel; + } - var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (expressionFactoryType.IsGenericTypeDefinition) + { + expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); + } - var methodGenericArguments = projectableMemberInfo switch { - MethodInfo methodInfo => methodInfo.GetGenericArguments(), - _ => null - }; + var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - if (expressionFactoryMethod is not null) - { - if (methodGenericArguments is { Length: > 0 }) - { - expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); - } + if (expressionFactoryMethod is null) + { + return _reflectionNotFoundSentinel; + } - return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); - } + if (projectableMemberInfo is MethodInfo mi && mi.GetGenericArguments() is { Length: > 0 } methodGenericArgs) + { + expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArgs); } - return null; + // Compile a native delegate: () => (LambdaExpression)GeneratedClass.Expression() + // Expression.Call + Convert handles the covariant return type (Expression → LambdaExpression). + // The one-time Compile() cost is amortized; all subsequent calls are direct native-code invocations. + var call = Expression.Call(expressionFactoryMethod); + var cast = Expression.Convert(call, typeof(LambdaExpression)); + return Expression.Lambda>(cast).Compile(); + } + + /// + /// Projects an array of objects — in practice always the + /// Type[] returned by — to a string[] + /// of simple type names without allocating a LINQ enumerator or intermediate delegate. + /// + private static string[] NestedTypePathToNames(Type[] types) + { + var names = new string[types.Length]; + for (var i = 0; i < types.Length; i++) + { + names[i] = types[i].Name; + } + + return names; } + /// + /// Returns the C#-formatted full name of . + /// Results are memoised in ; the same object + /// is encountered repeatedly across projectable members (e.g. int, string). + /// private static string GetFullTypeName(Type type) + => _typeNameCache.GetValue(type, static t => ComputeFullTypeName(t)); + + private static string ComputeFullTypeName(Type type) { // Handle generic type parameters (e.g., T, TEntity) if (type.IsGenericParameter) @@ -284,24 +420,13 @@ private static string GetFullTypeName(Type type) return type.Name; } - private static string? GetCSharpKeyword(Type type) - { - if (type == typeof(bool)) return "bool"; - if (type == typeof(byte)) return "byte"; - if (type == typeof(sbyte)) return "sbyte"; - if (type == typeof(char)) return "char"; - if (type == typeof(decimal)) return "decimal"; - if (type == typeof(double)) return "double"; - if (type == typeof(float)) return "float"; - if (type == typeof(int)) return "int"; - if (type == typeof(uint)) return "uint"; - if (type == typeof(long)) return "long"; - if (type == typeof(ulong)) return "ulong"; - if (type == typeof(short)) return "short"; - if (type == typeof(ushort)) return "ushort"; - if (type == typeof(object)) return "object"; - if (type == typeof(string)) return "string"; - return null; - } + /// + /// O(1) dictionary lookup — replaces the original 16 sequential if checks. + /// Note: reordering the entries in has no effect on + /// performance because uses hashing, not linear scan. + /// (Reordering only mattered with the old if-chain, where placing int / string + /// / bool first would have reduced average comparisons from ~8 to ~1.) + /// + private static string? GetCSharpKeyword(Type type) => _csharpKeywords.GetValueOrDefault(type); } } diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs index 3637cbf..2732f44 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs @@ -60,7 +60,7 @@ public void GetNestedTypePath_InnerType_Returns2Entries() var result = subject.GetNestedTypePath(); - Assert.Equal(2, result.Count()); + Assert.Equal(2, result.Length); } [Fact] @@ -70,7 +70,7 @@ public void GetNestedTypePath_SubsequentlyInnerType_Returns3Entries() var result = subject.GetNestedTypePath(); - Assert.Equal(3, result.Count()); + Assert.Equal(3, result.Length); } [Fact] diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs index 9fe3ec9..19d5839 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs @@ -15,14 +15,15 @@ public class ProjectableExpressionReplacerTests { public class ProjectableExpressionResolverStub : IProjectionExpressionResolver { - readonly Func _implementation; + readonly Func _implementation; - public ProjectableExpressionResolverStub(Func implementation) + public ProjectableExpressionResolverStub(Func implementation) { _implementation = implementation; } - public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) => _implementation(projectableMemberInfo); + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null) => _implementation(projectableMemberInfo, projectableAttribute); } class Entity @@ -58,7 +59,7 @@ public void VisitMember_SimpleProperty() Expression> expected = x => 0; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -74,7 +75,7 @@ public void VisitMember_SimpleMethod() Expression> expected = x => 0; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -90,7 +91,7 @@ public void VisitMember_SimpleMethodWithArguments() Expression> expected = x => 0; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -106,7 +107,7 @@ public void VisitMember_SimpleStatefullProperty() Expression> expected = x => x.Id; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -122,7 +123,7 @@ public void VisitMember_SimpleStatefullMethod() Expression> expected = x => x.Id; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -138,7 +139,7 @@ public void VisitMember_SimpleStaticMethod() Expression> expected = x => 0; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -154,7 +155,7 @@ public void VisitMember_SimpleStaticMethodWithArguments() Expression> expected = x => 0; var resolver = new ProjectableExpressionResolverStub( - x => expected + (x, a) => expected ); var subject = new ProjectableExpressionReplacer(resolver); @@ -188,7 +189,7 @@ public void VisitMember_CompilerGeneratedClosure_PropertyInfoBranch_FallsThrough var memberAccess = Expression.MakeMemberAccess(closureConst, propertyInfo); var resolver = new ProjectableExpressionResolverStub( - _ => throw new InvalidOperationException("Resolver should not be called for non-projectable members.") + (x, a) => throw new InvalidOperationException("Resolver should not be called for non-projectable members.") ); var subject = new ProjectableExpressionReplacer(resolver);