diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs new file mode 100644 index 0000000..f129ac8 --- /dev/null +++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs @@ -0,0 +1,175 @@ +using FluentAssertions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests; + +/// +/// Tests that verify the source generator correctly handles error types defined in external assemblies. +/// This addresses the issue where short type names were used instead of fully qualified names, +/// causing compilation errors when error types came from referenced assemblies. +/// +public class ExternalAssemblyErrorTypeTests +{ + [Fact] + public void GeneratesFullyQualifiedTypeNamesForExternalErrorType_WithToOkHttpResult() + { + // Create a fake external assembly with an error type + var externalErrorTypeCode = """ + namespace MyApp.Infrastructure.Errors; + + public sealed record NotFoundError(string Message); + """; + + var externalAssemblySyntaxTree = CSharpSyntaxTree.ParseText(externalErrorTypeCode); + var externalCompilation = CSharpCompilation.Create( + "ExternalAssembly", + [externalAssemblySyntaxTree], + [MetadataReference.CreateFromFile(typeof(object).Assembly.Location)], + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + ); + + // Emit the external assembly to get a reference + using var ms = new MemoryStream(); + var emitResult = externalCompilation.Emit(ms); + emitResult.Success.Should().BeTrue("External assembly should compile"); + + var externalAssemblyReference = MetadataReference.CreateFromStream(new MemoryStream(ms.ToArray())); + + // Create mapper in the main assembly that references the external error type + var mapperCode = """ + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.HttpResults; + using CSharpFunctionalExtensions.HttpResults; + using MyApp.Infrastructure.Errors; + + namespace MyApp.Api.ErrorMappers; + + public class NotFoundErrorMapper : IResultErrorMapper + { + public ProblemHttpResult Map(NotFoundError error) => + TypedResults.Problem( + statusCode: StatusCodes.Status404NotFound, + title: "Not Found", + detail: error.Message + ); + } + """; + + var (_, generatedSource) = GeneratorTestHelper.RunGenerator(mapperCode, [externalAssemblyReference]); + + generatedSource.Should().Contain("this Result result"); + } + + [Fact] + public void GeneratesFullyQualifiedTypeNamesForExternalErrorType_MultipleMappers() + { + // Create fake external assemblies with error types + var externalErrorTypesCode = """ + namespace MyApp.Infrastructure.Errors; + + public sealed record NotFoundError(string Message); + public sealed record ValidationError(string Message); + """; + + var externalAssemblySyntaxTree = CSharpSyntaxTree.ParseText(externalErrorTypesCode); + var externalCompilation = CSharpCompilation.Create( + "ExternalAssembly", + [externalAssemblySyntaxTree], + [MetadataReference.CreateFromFile(typeof(object).Assembly.Location)], + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + ); + + using var ms = new MemoryStream(); + var emitResult = externalCompilation.Emit(ms); + emitResult.Success.Should().BeTrue(); + + var externalAssemblyReference = MetadataReference.CreateFromStream(new MemoryStream(ms.ToArray())); + + // Create multiple mappers + var mappersCode = """ + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.HttpResults; + using CSharpFunctionalExtensions.HttpResults; + using MyApp.Infrastructure.Errors; + + namespace MyApp.Api.ErrorMappers; + + public class NotFoundErrorMapper : IResultErrorMapper + { + public ProblemHttpResult Map(NotFoundError error) => + TypedResults.Problem( + statusCode: StatusCodes.Status404NotFound, + title: "Not Found", + detail: error.Message + ); + } + + public class ValidationErrorMapper : IResultErrorMapper> + { + public BadRequest Map(ValidationError error) => + TypedResults.BadRequest( + TypedResults.Problem( + statusCode: StatusCodes.Status400BadRequest, + title: "Validation Error", + detail: error.Message + ) + ); + } + """; + + var (_, generatedSource) = GeneratorTestHelper.RunGenerator(mappersCode, [externalAssemblyReference]); + + generatedSource.Should().Contain("this Result + { + public ProblemHttpResult Map(NotFoundError error) => + TypedResults.Problem( + statusCode: StatusCodes.Status404NotFound, + title: "Not Found", + detail: error.Message + ); + } + """; + + var (diagnostics, generatedSource) = GeneratorTestHelper.RunGenerator(mapperCode, [externalAssemblyReference]); + + diagnostics.Should().BeEmpty("Should not report diagnostics about NotFoundError being unresolved"); + generatedSource.Should().NotBeNullOrEmpty("Should generate extension methods"); + } +} diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs new file mode 100644 index 0000000..8486c56 --- /dev/null +++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs @@ -0,0 +1,44 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests; + +public static class GeneratorTestHelper +{ + public static (IEnumerable Diagnostics, string GeneratedSource) RunGenerator( + string sourceCode, + IEnumerable? additionalReferences = null + ) + { + var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode); + + var references = new List + { + MetadataReference.CreateFromFile(typeof(ResultExtensionsGenerator).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IResultErrorMapper<,>).Assembly.Location), + }; + + if (additionalReferences != null) + references.AddRange(additionalReferences); + + var compilation = CSharpCompilation.Create( + "TestAssembly", + [syntaxTree], + references.OfType(), + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + ); + + var generator = new ResultExtensionsGenerator(); + var driver = CSharpGeneratorDriver.Create(generator); + + driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics); + + var sourceFiles = outputCompilation + .SyntaxTrees.Where(tree => tree.FilePath.EndsWith(".g.cs", StringComparison.OrdinalIgnoreCase)) + .Select(tree => tree.GetText().ToString()); + + var generatedSource = string.Join("\n\n", sourceFiles); + + return (diagnostics, generatedSource); + } +} diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs deleted file mode 100644 index 60589e6..0000000 --- a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs +++ /dev/null @@ -1,32 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; - -namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests; - -public static class ResultExtensionsGeneratorTestHelper -{ - public static IEnumerable RunGenerator(string sourceCode) - { - var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode); - - var references = new[] - { - MetadataReference.CreateFromFile(typeof(ResultExtensionsGenerator).Assembly.Location), - MetadataReference.CreateFromFile(typeof(IResultErrorMapper<,>).Assembly.Location), - }.ToList(); - - var compilation = CSharpCompilation.Create( - "TestAssembly", - [syntaxTree], - references, - new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) - ); - - var generator = new ResultExtensionsGenerator(); - var driver = CSharpGeneratorDriver.Create(generator); - - driver.RunGeneratorsAndUpdateCompilation(compilation, out _, out var diagnostics); - - return diagnostics; - } -} diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs index d28f178..48e19d6 100644 --- a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs +++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs @@ -29,11 +29,13 @@ public class DocumentCreationErrorMapper2 : IResultErrorMapper _mapperClasses; - private readonly HashSet _requiredNamespaces; - protected ClassBuilder(HashSet requiredNamespaces, List mapperClasses) + protected ClassBuilder(List mapperClasses, Compilation? compilation = null) { - _requiredNamespaces = requiredNamespaces; _mapperClasses = mapperClasses; + _compilation = compilation; } private static string DefaultUsings => """ using CSharpFunctionalExtensions; - using IResult = Microsoft.AspNetCore.Http.IResult; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.Net.Http.Headers; using System.Text; + using IResult = Microsoft.AspNetCore.Http.IResult; """; public string SourceFileName => $"{ClassName}.g.cs"; @@ -42,13 +44,6 @@ public string Build() sourceBuilder.AppendLine(); sourceBuilder.AppendLine(DefaultUsings); - _requiredNamespaces - .Where(@namespace => !@namespace.StartsWith("global")) - .Distinct() - .Select(@namespace => $"using {@namespace};") - .ToList() - .ForEach(@using => sourceBuilder.AppendLine(@using)); - sourceBuilder.AppendLine(); sourceBuilder.AppendLine(ClassSummary); @@ -68,8 +63,8 @@ public string Build() if (mappingMethod.ParameterList.Parameters.Count != 1) throw new ArgumentException($"Mapping method in class {mapperClassName} must have exactly one parameter."); - var resultErrorType = mappingMethod.ParameterList.Parameters[0].Type!.ToString(); - var httpResultType = mappingMethod.ReturnType.ToString(); + var resultErrorType = GetFullyQualifiedTypeName(mapperClass, mappingMethod.ParameterList.Parameters[0].Type!); + var httpResultType = mappingMethod.ReturnType!.ToString(); foreach (var methodGenerator in MethodGenerators) { @@ -83,4 +78,18 @@ public string Build() return sourceBuilder.ToString(); } + + private string GetFullyQualifiedTypeName(ClassDeclarationSyntax mapperClass, TypeSyntax typeSyntax) + { + if (_compilation == null) + return typeSyntax.ToString(); + + var semanticModel = _compilation.GetSemanticModel(mapperClass.SyntaxTree); + var typeInfo = semanticModel.GetTypeInfo(typeSyntax); + + if (typeInfo.Type == null) + return typeSyntax.ToString(); + + return TypeNameResolver.GetFullyQualifiedTypeName(typeInfo.Type); + } } diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs index cc1bf74..86ae6a7 100644 --- a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs +++ b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs @@ -1,12 +1,11 @@ using CSharpFunctionalExtensions.HttpResults.Generators.ResultExtensions; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; namespace CSharpFunctionalExtensions.HttpResults.Generators.Builders; -public class ResultExtensionsClassBuilder( - HashSet requiredNamespaces, - List mapperClasses -) : ClassBuilder(requiredNamespaces, mapperClasses) +public class ResultExtensionsClassBuilder(List mapperClasses, Compilation? compilation = null) + : ClassBuilder(mapperClasses, compilation) { protected override string ClassName => "ResultExtensions"; diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs index 4406352..7f03bd1 100644 --- a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs +++ b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs @@ -1,12 +1,13 @@ using CSharpFunctionalExtensions.HttpResults.Generators.UnitResultExtensions; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; namespace CSharpFunctionalExtensions.HttpResults.Generators.Builders; public class UnitResultExtensionsClassBuilder( - HashSet requiredNamespaces, - List mapperClasses -) : ClassBuilder(requiredNamespaces, mapperClasses) + List mapperClasses, + Compilation? compilation = null +) : ClassBuilder(mapperClasses, compilation) { protected override string ClassName => "UnitResultExtensions"; diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs b/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs index 5313fd8..1e99b2d 100644 --- a/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs +++ b/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs @@ -1,6 +1,6 @@ -using System.Diagnostics; -using System.Text; +using System.Text; using CSharpFunctionalExtensions.HttpResults.Generators.Builders; +using CSharpFunctionalExtensions.HttpResults.Generators.Utils; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; @@ -44,19 +44,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var (compilation, classDeclarations) = source; var mapperClasses = new List(); - var requiredNamespaces = new HashSet(); Parallel.ForEach( classDeclarations, classDeclaration => { - var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree); - var namespaceName = GetNamespace(classDeclaration, semanticModel); - lock (mapperClasses) { mapperClasses.Add(classDeclaration); - requiredNamespaces.Add(namespaceName); } } ); @@ -64,13 +59,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) if (!ResultExtensionsGeneratorValidator.CheckRules(mapperClasses, context)) return; - var (fileName, sourceText) = CreateErrorMapperInstancesClass(mapperClasses, requiredNamespaces); + var (fileName, sourceText) = CreateErrorMapperInstancesClass(mapperClasses, compilation); context.AddSource(fileName, SourceText.From(sourceText, Encoding.UTF8)); var classBuilders = new List { - new ResultExtensionsClassBuilder(requiredNamespaces, mapperClasses), - new UnitResultExtensionsClassBuilder(requiredNamespaces, mapperClasses), + new ResultExtensionsClassBuilder(mapperClasses, compilation), + new UnitResultExtensionsClassBuilder(mapperClasses, compilation), }; foreach (var classBuilder in classBuilders) @@ -84,7 +79,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) /// private static (string FileName, string SourceText) CreateErrorMapperInstancesClass( List mapperClasses, - HashSet requiredNamespaces + Compilation compilation ) { var sourceBuilder = new StringBuilder(); @@ -93,20 +88,20 @@ HashSet requiredNamespaces sourceBuilder.AppendLine(); sourceBuilder.AppendLine("#nullable enable"); sourceBuilder.AppendLine(); - requiredNamespaces - .Where(@namespace => !@namespace.StartsWith("global")) - .Distinct() - .Select(@namespace => $"using {@namespace};") - .ToList() - .ForEach(@using => sourceBuilder.AppendLine(@using)); sourceBuilder.AppendLine(); sourceBuilder.AppendLine("public static class ErrorMapperInstances {"); - foreach (var mapperName in mapperClasses) - sourceBuilder.AppendLine( - $" public static {mapperName.Identifier.Text} {mapperName.Identifier.Text} {{ get; }} = new();" - ); + foreach (var mapper in mapperClasses) + { + var semanticModel = compilation.GetSemanticModel(mapper.SyntaxTree); + + if (semanticModel.GetDeclaredSymbol(mapper) is not ITypeSymbol mapperSymbol) + continue; + + var mapperType = TypeNameResolver.GetFullyQualifiedTypeName(mapperSymbol); + sourceBuilder.AppendLine($" public static {mapperType} {mapper.Identifier.Text} {{ get; }} = new();"); + } sourceBuilder.AppendLine("}"); @@ -128,16 +123,4 @@ private static bool ImplementsResultErrorMapper(ITypeSymbol? classSymbol) interfaceSymbol.Name.StartsWith(ResultErrorMapperInterface) ); } - - /// - /// Retrieves the namespace of a class declaration. - /// - /// The class declaration syntax node. - /// The semantic model for the syntax tree. - /// The namespace of the class, or an empty string if the namespace cannot be determined. - private static string GetNamespace(ClassDeclarationSyntax classDeclaration, SemanticModel semanticModel) - { - var symbol = semanticModel.GetDeclaredSymbol(classDeclaration); - return symbol?.ContainingNamespace?.ToString() ?? string.Empty; - } } diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs new file mode 100644 index 0000000..28c6e80 --- /dev/null +++ b/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs @@ -0,0 +1,17 @@ +using Microsoft.CodeAnalysis; + +namespace CSharpFunctionalExtensions.HttpResults.Generators.Utils; + +internal static class TypeNameResolver +{ + private static readonly SymbolDisplayFormat FullyQualifiedWithNullables = + SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions( + SymbolDisplayFormat.FullyQualifiedFormat.MiscellaneousOptions + | SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier + ); + + public static string GetFullyQualifiedTypeName(ITypeSymbol typeSymbol) + { + return typeSymbol.ToDisplayString(FullyQualifiedWithNullables); + } +}