Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<TargetFrameworks>net8.0;net10.0</TargetFrameworks>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<LangVersion>12.0</LangVersion>
<LangVersion Condition="'$(TargetFramework)' == 'net10.0'">14.0</LangVersion>
<Nullable>enable</Nullable>
<EnableNETAnalyzers>true</EnableNETAnalyzers>
<NoWarn>CS1591</NoWarn>
Expand Down
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
<PackageVersion Include="BenchmarkDotNet" Version="0.14.0" />
<PackageVersion Include="coverlet.collector" Version="6.0.2" />
<PackageVersion Include="Microsoft.Extensions.Logging.Console" Version="8.0.1" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.11.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="5.0.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.Analyzers" Version="3.11.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ public class ExpressionSyntaxRewriter : CSharpSyntaxRewriter
readonly NullConditionalRewriteSupport _nullConditionalRewriteSupport;
readonly SourceProductionContext _context;
readonly Stack<ExpressionSyntax> _conditionalAccessExpressionsStack = new();
readonly string? _extensionParameterName;

public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullConditionalRewriteSupport nullConditionalRewriteSupport, SemanticModel semanticModel, SourceProductionContext context)
public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullConditionalRewriteSupport nullConditionalRewriteSupport, SemanticModel semanticModel, SourceProductionContext context, string? extensionParameterName = null)
{
_targetTypeSymbol = targetTypeSymbol;
_nullConditionalRewriteSupport = nullConditionalRewriteSupport;
_semanticModel = semanticModel;
_context = context;
_extensionParameterName = extensionParameterName;
}

private SyntaxNode? VisitThisBaseExpression(CSharpSyntaxNode node)
Expand Down Expand Up @@ -281,8 +283,22 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition

public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
var symbol = _semanticModel.GetSymbolInfo(node).Symbol;
if (symbol is not null)
// Handle C# 14 extension parameter replacement (e.g., `e` in `extension(Entity e)` becomes `@this`)
if (_extensionParameterName is not null && node.Identifier.Text == _extensionParameterName)
{
var symbol = _semanticModel.GetSymbolInfo(node).Symbol;

// Check if this identifier refers to the extension parameter
if (symbol is IParameterSymbol { ContainingSymbol: INamedTypeSymbol { IsExtension: true } })
{
return SyntaxFactory.IdentifierName("@this")
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}

var identifierSymbol = _semanticModel.GetSymbolInfo(node).Symbol;
if (identifierSymbol is not null)
{
var operation = node switch { { Parent: { } parent } when parent.IsKind(SyntaxKind.InvocationExpression) => _semanticModel.GetOperation(node.Parent),
_ => _semanticModel.GetOperation(node!)
Expand Down Expand Up @@ -337,10 +353,10 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition
}

// if this node refers to a named type which is not yet fully qualified, we want to fully qualify it
if (symbol.Kind is SymbolKind.NamedType && node.Parent?.Kind() is not SyntaxKind.QualifiedName)
if (identifierSymbol.Kind is SymbolKind.NamedType && node.Parent?.Kind() is not SyntaxKind.QualifiedName)
{
return SyntaxFactory.ParseTypeName(
symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
identifierSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
).WithLeadingTrivia(node.GetLeadingTrivia());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ static IEnumerable<string> GetNestedInClassPath(ITypeSymbol namedTypeSymbol)
yield return namedTypeSymbol.Name;
}

/// <summary>
/// Gets the nested class path for extension members, skipping the extension block itself
/// and using the outer class as the containing type.
/// </summary>
static IEnumerable<string> GetNestedInClassPathForExtensionMember(ITypeSymbol extensionType)
{
// For extension members, the ContainingType is the extension block,
// and its ContainingType is the outer class (e.g., EntityExtensions)
var outerType = extensionType.ContainingType;

if (outerType is not null)
{
return GetNestedInClassPath(outerType);
}

return [];
}

public static ProjectableDescriptor? GetDescriptor(Compilation compilation, MemberDeclarationSyntax member, SourceProductionContext context)
{
var semanticModel = compilation.GetSemanticModel(member.SyntaxTree);
Expand Down Expand Up @@ -115,24 +133,52 @@ x is IPropertySymbol xProperty &&
if (memberBody is null) return null;
}

var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, semanticModel, context);
// Check if this member is inside a C# 14 extension block
var isExtensionMember = memberSymbol.ContainingType is { IsExtension: true };
IParameterSymbol? extensionParameter = null;
ITypeSymbol? extensionReceiverType = null;

if (isExtensionMember && memberSymbol.ContainingType is { } extensionType)
{
extensionParameter = extensionType.ExtensionParameter;
extensionReceiverType = extensionParameter?.Type;
}

// For extension members, use the extension receiver type for rewriting
var targetTypeForRewriting = isExtensionMember && extensionReceiverType is INamedTypeSymbol receiverNamedType
? receiverNamedType
: memberSymbol.ContainingType;

var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(
targetTypeForRewriting,
nullConditionalRewriteSupport,
semanticModel,
context,
extensionParameter?.Name);
var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel);

var descriptor = new ProjectableDescriptor {
// For extension members, use the outer class for class naming
var classForNaming = isExtensionMember && memberSymbol.ContainingType.ContainingType is not null
? memberSymbol.ContainingType.ContainingType
: memberSymbol.ContainingType;

var descriptor = new ProjectableDescriptor
{
UsingDirectives = member.SyntaxTree.GetRoot().DescendantNodes().OfType<UsingDirectiveSyntax>(),
ClassName = memberSymbol.ContainingType.Name,
ClassNamespace = memberSymbol.ContainingType.ContainingNamespace.IsGlobalNamespace ? null : memberSymbol.ContainingType.ContainingNamespace.ToDisplayString(),
ClassName = classForNaming.Name,
ClassNamespace = classForNaming.ContainingNamespace.IsGlobalNamespace ? null : classForNaming.ContainingNamespace.ToDisplayString(),
MemberName = memberSymbol.Name,
NestedInClassNames = GetNestedInClassPath(memberSymbol.ContainingType),
NestedInClassNames = isExtensionMember
? GetNestedInClassPathForExtensionMember(memberSymbol.ContainingType)
: GetNestedInClassPath(memberSymbol.ContainingType),
ParametersList = SyntaxFactory.ParameterList()
};

if (memberSymbol.ContainingType is INamedTypeSymbol { IsGenericType: true } containingNamedType)
if (classForNaming is { IsGenericType: true })
{
descriptor.ClassTypeParameterList = SyntaxFactory.TypeParameterList();

foreach (var additionalClassTypeParameter in containingNamedType.TypeParameters)
foreach (var additionalClassTypeParameter in classForNaming.TypeParameters)
{
descriptor.ClassTypeParameterList = descriptor.ClassTypeParameterList.AddParameters(
SyntaxFactory.TypeParameter(additionalClassTypeParameter.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
Expand Down Expand Up @@ -182,7 +228,21 @@ x is IPropertySymbol xProperty &&
}
}

if (!member.Modifiers.Any(SyntaxKind.StaticKeyword))
// Handle extension members - add @this parameter with the extension receiver type
if (isExtensionMember && extensionReceiverType is not null)
{
descriptor.ParametersList = descriptor.ParametersList.AddParameters(
SyntaxFactory.Parameter(
SyntaxFactory.Identifier("@this")
)
.WithType(
SyntaxFactory.ParseTypeName(
extensionReceiverType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
)
)
);
}
else if (!member.Modifiers.Any(SyntaxKind.StaticKeyword))
{
descriptor.ParametersList = descriptor.ParametersList.AddParameters(
SyntaxFactory.Parameter(
Expand All @@ -198,7 +258,13 @@ x is IPropertySymbol xProperty &&

var methodSymbol = memberSymbol as IMethodSymbol;

if (methodSymbol is { IsExtensionMethod: true })
// Handle target type for extension members
if (isExtensionMember && extensionReceiverType is not null)
{
descriptor.TargetClassNamespace = extensionReceiverType.ContainingNamespace.IsGlobalNamespace ? null : extensionReceiverType.ContainingNamespace.ToDisplayString();
descriptor.TargetNestedInClassNames = GetNestedInClassPath(extensionReceiverType);
}
else if (methodSymbol is { IsExtensionMethod: true })
{
var targetTypeSymbol = methodSymbol.Parameters.First().Type;
descriptor.TargetClassNamespace = targetTypeSymbol.ContainingNamespace.IsGlobalNamespace ? null : targetTypeSymbol.ContainingNamespace.ToDisplayString();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#if NET10_0_OR_GREATER
namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMembers
{
public class Entity
{
public int Id { get; set; }
public string Name { get; set; } = string.Empty;
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#if NET10_0_OR_GREATER
namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMembers
{
public static class EntityExtensions
{
extension(Entity e)
{
/// <summary>
/// Extension member property that doubles the entity's ID.
/// </summary>
[Projectable]
public int DoubleId => e.Id * 2;

/// <summary>
/// Extension member method that triples the entity's ID.
/// </summary>
[Projectable]
public int TripleId() => e.Id * 3;

/// <summary>
/// Extension member method that multiplies the entity's ID by a factor.
/// </summary>
[Projectable]
public int Multiply(int factor) => e.Id * factor;
}
}

public static class IntExtensions
{
extension(int i)
{
/// <summary>
/// Extension member property that squares an integer.
/// </summary>
[Projectable]
public int SquaredMember => i * i;
}
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id] * 3
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id] * 3
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id] * 5
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id] * 5
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#if NET10_0_OR_GREATER
using System.Linq;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
using Microsoft.EntityFrameworkCore;
using VerifyXunit;
using Xunit;

namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMembers
{
/// <summary>
/// Tests for C# 14 extension member support.
/// These tests only run on .NET 10+ where extension members are supported.
/// Note: Extension properties cannot currently be used directly in LINQ expression trees (CS9296),
/// so only extension methods are tested here.
/// </summary>
[UsesVerify]
public class ExtensionMemberTests
{
[Fact]
public Task ExtensionMemberMethodOnEntity()
{
using var dbContext = new SampleDbContext<Entity>();

var query = dbContext.Set<Entity>()
.Select(x => x.TripleId());

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ExtensionMemberMethodWithParameterOnEntity()
{
using var dbContext = new SampleDbContext<Entity>();

var query = dbContext.Set<Entity>()
.Select(x => x.Multiply(5));

return Verifier.Verify(query.ToQueryString());
}
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// <auto-generated/>
#nullable disable
using System;
using EntityFrameworkCore.Projectables;
using Foo;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_TripleId
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, int>> Expression()
{
return (global::Foo.Entity @this) => @this.Id * 3;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// <auto-generated/>
#nullable disable
using System;
using EntityFrameworkCore.Projectables;
using Foo;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_TripleId
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, int>> Expression()
{
return (global::Foo.Entity @this) => @this.Id * 3;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// <auto-generated/>
#nullable disable
using System;
using EntityFrameworkCore.Projectables;
using Foo;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_Multiply
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, int, int>> Expression()
{
return (global::Foo.Entity @this, int factor) => @this.Id * factor;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// <auto-generated/>
#nullable disable
using System;
using EntityFrameworkCore.Projectables;
using Foo;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_Multiply
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, int, int>> Expression()
{
return (global::Foo.Entity @this, int factor) => @this.Id * factor;
}
}
}
Loading
Loading