diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs
new file mode 100644
index 0000000..f129ac8
--- /dev/null
+++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ExternalAssemblyErrorTypeTests.cs
@@ -0,0 +1,175 @@
+using FluentAssertions;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+
+namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests;
+
+///
+/// Tests that verify the source generator correctly handles error types defined in external assemblies.
+/// This addresses the issue where short type names were used instead of fully qualified names,
+/// causing compilation errors when error types came from referenced assemblies.
+///
+public class ExternalAssemblyErrorTypeTests
+{
+ [Fact]
+ public void GeneratesFullyQualifiedTypeNamesForExternalErrorType_WithToOkHttpResult()
+ {
+ // Create a fake external assembly with an error type
+ var externalErrorTypeCode = """
+ namespace MyApp.Infrastructure.Errors;
+
+ public sealed record NotFoundError(string Message);
+ """;
+
+ var externalAssemblySyntaxTree = CSharpSyntaxTree.ParseText(externalErrorTypeCode);
+ var externalCompilation = CSharpCompilation.Create(
+ "ExternalAssembly",
+ [externalAssemblySyntaxTree],
+ [MetadataReference.CreateFromFile(typeof(object).Assembly.Location)],
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
+ );
+
+ // Emit the external assembly to get a reference
+ using var ms = new MemoryStream();
+ var emitResult = externalCompilation.Emit(ms);
+ emitResult.Success.Should().BeTrue("External assembly should compile");
+
+ var externalAssemblyReference = MetadataReference.CreateFromStream(new MemoryStream(ms.ToArray()));
+
+ // Create mapper in the main assembly that references the external error type
+ var mapperCode = """
+ using Microsoft.AspNetCore.Http;
+ using Microsoft.AspNetCore.Http.HttpResults;
+ using CSharpFunctionalExtensions.HttpResults;
+ using MyApp.Infrastructure.Errors;
+
+ namespace MyApp.Api.ErrorMappers;
+
+ public class NotFoundErrorMapper : IResultErrorMapper
+ {
+ public ProblemHttpResult Map(NotFoundError error) =>
+ TypedResults.Problem(
+ statusCode: StatusCodes.Status404NotFound,
+ title: "Not Found",
+ detail: error.Message
+ );
+ }
+ """;
+
+ var (_, generatedSource) = GeneratorTestHelper.RunGenerator(mapperCode, [externalAssemblyReference]);
+
+ generatedSource.Should().Contain("this Result result");
+ }
+
+ [Fact]
+ public void GeneratesFullyQualifiedTypeNamesForExternalErrorType_MultipleMappers()
+ {
+ // Create fake external assemblies with error types
+ var externalErrorTypesCode = """
+ namespace MyApp.Infrastructure.Errors;
+
+ public sealed record NotFoundError(string Message);
+ public sealed record ValidationError(string Message);
+ """;
+
+ var externalAssemblySyntaxTree = CSharpSyntaxTree.ParseText(externalErrorTypesCode);
+ var externalCompilation = CSharpCompilation.Create(
+ "ExternalAssembly",
+ [externalAssemblySyntaxTree],
+ [MetadataReference.CreateFromFile(typeof(object).Assembly.Location)],
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
+ );
+
+ using var ms = new MemoryStream();
+ var emitResult = externalCompilation.Emit(ms);
+ emitResult.Success.Should().BeTrue();
+
+ var externalAssemblyReference = MetadataReference.CreateFromStream(new MemoryStream(ms.ToArray()));
+
+ // Create multiple mappers
+ var mappersCode = """
+ using Microsoft.AspNetCore.Http;
+ using Microsoft.AspNetCore.Http.HttpResults;
+ using CSharpFunctionalExtensions.HttpResults;
+ using MyApp.Infrastructure.Errors;
+
+ namespace MyApp.Api.ErrorMappers;
+
+ public class NotFoundErrorMapper : IResultErrorMapper
+ {
+ public ProblemHttpResult Map(NotFoundError error) =>
+ TypedResults.Problem(
+ statusCode: StatusCodes.Status404NotFound,
+ title: "Not Found",
+ detail: error.Message
+ );
+ }
+
+ public class ValidationErrorMapper : IResultErrorMapper>
+ {
+ public BadRequest Map(ValidationError error) =>
+ TypedResults.BadRequest(
+ TypedResults.Problem(
+ statusCode: StatusCodes.Status400BadRequest,
+ title: "Validation Error",
+ detail: error.Message
+ )
+ );
+ }
+ """;
+
+ var (_, generatedSource) = GeneratorTestHelper.RunGenerator(mappersCode, [externalAssemblyReference]);
+
+ generatedSource.Should().Contain("this Result
+ {
+ public ProblemHttpResult Map(NotFoundError error) =>
+ TypedResults.Problem(
+ statusCode: StatusCodes.Status404NotFound,
+ title: "Not Found",
+ detail: error.Message
+ );
+ }
+ """;
+
+ var (diagnostics, generatedSource) = GeneratorTestHelper.RunGenerator(mapperCode, [externalAssemblyReference]);
+
+ diagnostics.Should().BeEmpty("Should not report diagnostics about NotFoundError being unresolved");
+ generatedSource.Should().NotBeNullOrEmpty("Should generate extension methods");
+ }
+}
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs
new file mode 100644
index 0000000..8486c56
--- /dev/null
+++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/GeneratorTestHelper.cs
@@ -0,0 +1,44 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+
+namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests;
+
+public static class GeneratorTestHelper
+{
+ public static (IEnumerable Diagnostics, string GeneratedSource) RunGenerator(
+ string sourceCode,
+ IEnumerable? additionalReferences = null
+ )
+ {
+ var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode);
+
+ var references = new List
+ {
+ MetadataReference.CreateFromFile(typeof(ResultExtensionsGenerator).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IResultErrorMapper<,>).Assembly.Location),
+ };
+
+ if (additionalReferences != null)
+ references.AddRange(additionalReferences);
+
+ var compilation = CSharpCompilation.Create(
+ "TestAssembly",
+ [syntaxTree],
+ references.OfType(),
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
+ );
+
+ var generator = new ResultExtensionsGenerator();
+ var driver = CSharpGeneratorDriver.Create(generator);
+
+ driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics);
+
+ var sourceFiles = outputCompilation
+ .SyntaxTrees.Where(tree => tree.FilePath.EndsWith(".g.cs", StringComparison.OrdinalIgnoreCase))
+ .Select(tree => tree.GetText().ToString());
+
+ var generatedSource = string.Join("\n\n", sourceFiles);
+
+ return (diagnostics, generatedSource);
+ }
+}
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs
deleted file mode 100644
index 60589e6..0000000
--- a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/ResultExtensionsGeneratorTestHelper.cs
+++ /dev/null
@@ -1,32 +0,0 @@
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-
-namespace CSharpFunctionalExtensions.HttpResults.Generators.Tests;
-
-public static class ResultExtensionsGeneratorTestHelper
-{
- public static IEnumerable RunGenerator(string sourceCode)
- {
- var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode);
-
- var references = new[]
- {
- MetadataReference.CreateFromFile(typeof(ResultExtensionsGenerator).Assembly.Location),
- MetadataReference.CreateFromFile(typeof(IResultErrorMapper<,>).Assembly.Location),
- }.ToList();
-
- var compilation = CSharpCompilation.Create(
- "TestAssembly",
- [syntaxTree],
- references,
- new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
- );
-
- var generator = new ResultExtensionsGenerator();
- var driver = CSharpGeneratorDriver.Create(generator);
-
- driver.RunGeneratorsAndUpdateCompilation(compilation, out _, out var diagnostics);
-
- return diagnostics;
- }
-}
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs
index d28f178..48e19d6 100644
--- a/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs
+++ b/CSharpFunctionalExtensions.HttpResults.Generators.Tests/Rules/DuplicateMapperRuleTests.cs
@@ -29,11 +29,13 @@ public class DocumentCreationErrorMapper2 : IResultErrorMapper _mapperClasses;
- private readonly HashSet _requiredNamespaces;
- protected ClassBuilder(HashSet requiredNamespaces, List mapperClasses)
+ protected ClassBuilder(List mapperClasses, Compilation? compilation = null)
{
- _requiredNamespaces = requiredNamespaces;
_mapperClasses = mapperClasses;
+ _compilation = compilation;
}
private static string DefaultUsings =>
"""
using CSharpFunctionalExtensions;
- using IResult = Microsoft.AspNetCore.Http.IResult;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Net.Http.Headers;
using System.Text;
+ using IResult = Microsoft.AspNetCore.Http.IResult;
""";
public string SourceFileName => $"{ClassName}.g.cs";
@@ -42,13 +44,6 @@ public string Build()
sourceBuilder.AppendLine();
sourceBuilder.AppendLine(DefaultUsings);
- _requiredNamespaces
- .Where(@namespace => !@namespace.StartsWith("global"))
- .Distinct()
- .Select(@namespace => $"using {@namespace};")
- .ToList()
- .ForEach(@using => sourceBuilder.AppendLine(@using));
-
sourceBuilder.AppendLine();
sourceBuilder.AppendLine(ClassSummary);
@@ -68,8 +63,8 @@ public string Build()
if (mappingMethod.ParameterList.Parameters.Count != 1)
throw new ArgumentException($"Mapping method in class {mapperClassName} must have exactly one parameter.");
- var resultErrorType = mappingMethod.ParameterList.Parameters[0].Type!.ToString();
- var httpResultType = mappingMethod.ReturnType.ToString();
+ var resultErrorType = GetFullyQualifiedTypeName(mapperClass, mappingMethod.ParameterList.Parameters[0].Type!);
+ var httpResultType = mappingMethod.ReturnType!.ToString();
foreach (var methodGenerator in MethodGenerators)
{
@@ -83,4 +78,18 @@ public string Build()
return sourceBuilder.ToString();
}
+
+ private string GetFullyQualifiedTypeName(ClassDeclarationSyntax mapperClass, TypeSyntax typeSyntax)
+ {
+ if (_compilation == null)
+ return typeSyntax.ToString();
+
+ var semanticModel = _compilation.GetSemanticModel(mapperClass.SyntaxTree);
+ var typeInfo = semanticModel.GetTypeInfo(typeSyntax);
+
+ if (typeInfo.Type == null)
+ return typeSyntax.ToString();
+
+ return TypeNameResolver.GetFullyQualifiedTypeName(typeInfo.Type);
+ }
}
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs
index cc1bf74..86ae6a7 100644
--- a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs
+++ b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/ResultExtensionsClassBuilder.cs
@@ -1,12 +1,11 @@
using CSharpFunctionalExtensions.HttpResults.Generators.ResultExtensions;
+using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace CSharpFunctionalExtensions.HttpResults.Generators.Builders;
-public class ResultExtensionsClassBuilder(
- HashSet requiredNamespaces,
- List mapperClasses
-) : ClassBuilder(requiredNamespaces, mapperClasses)
+public class ResultExtensionsClassBuilder(List mapperClasses, Compilation? compilation = null)
+ : ClassBuilder(mapperClasses, compilation)
{
protected override string ClassName => "ResultExtensions";
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs
index 4406352..7f03bd1 100644
--- a/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs
+++ b/CSharpFunctionalExtensions.HttpResults.Generators/Builders/UnitResultExtensionsClassBuilder.cs
@@ -1,12 +1,13 @@
using CSharpFunctionalExtensions.HttpResults.Generators.UnitResultExtensions;
+using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace CSharpFunctionalExtensions.HttpResults.Generators.Builders;
public class UnitResultExtensionsClassBuilder(
- HashSet requiredNamespaces,
- List mapperClasses
-) : ClassBuilder(requiredNamespaces, mapperClasses)
+ List mapperClasses,
+ Compilation? compilation = null
+) : ClassBuilder(mapperClasses, compilation)
{
protected override string ClassName => "UnitResultExtensions";
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs b/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs
index 5313fd8..1e99b2d 100644
--- a/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs
+++ b/CSharpFunctionalExtensions.HttpResults.Generators/ResultExtensionsGenerator.cs
@@ -1,6 +1,6 @@
-using System.Diagnostics;
-using System.Text;
+using System.Text;
using CSharpFunctionalExtensions.HttpResults.Generators.Builders;
+using CSharpFunctionalExtensions.HttpResults.Generators.Utils;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
@@ -44,19 +44,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var (compilation, classDeclarations) = source;
var mapperClasses = new List();
- var requiredNamespaces = new HashSet();
Parallel.ForEach(
classDeclarations,
classDeclaration =>
{
- var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
- var namespaceName = GetNamespace(classDeclaration, semanticModel);
-
lock (mapperClasses)
{
mapperClasses.Add(classDeclaration);
- requiredNamespaces.Add(namespaceName);
}
}
);
@@ -64,13 +59,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
if (!ResultExtensionsGeneratorValidator.CheckRules(mapperClasses, context))
return;
- var (fileName, sourceText) = CreateErrorMapperInstancesClass(mapperClasses, requiredNamespaces);
+ var (fileName, sourceText) = CreateErrorMapperInstancesClass(mapperClasses, compilation);
context.AddSource(fileName, SourceText.From(sourceText, Encoding.UTF8));
var classBuilders = new List
{
- new ResultExtensionsClassBuilder(requiredNamespaces, mapperClasses),
- new UnitResultExtensionsClassBuilder(requiredNamespaces, mapperClasses),
+ new ResultExtensionsClassBuilder(mapperClasses, compilation),
+ new UnitResultExtensionsClassBuilder(mapperClasses, compilation),
};
foreach (var classBuilder in classBuilders)
@@ -84,7 +79,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
///
private static (string FileName, string SourceText) CreateErrorMapperInstancesClass(
List mapperClasses,
- HashSet requiredNamespaces
+ Compilation compilation
)
{
var sourceBuilder = new StringBuilder();
@@ -93,20 +88,20 @@ HashSet requiredNamespaces
sourceBuilder.AppendLine();
sourceBuilder.AppendLine("#nullable enable");
sourceBuilder.AppendLine();
- requiredNamespaces
- .Where(@namespace => !@namespace.StartsWith("global"))
- .Distinct()
- .Select(@namespace => $"using {@namespace};")
- .ToList()
- .ForEach(@using => sourceBuilder.AppendLine(@using));
sourceBuilder.AppendLine();
sourceBuilder.AppendLine("public static class ErrorMapperInstances {");
- foreach (var mapperName in mapperClasses)
- sourceBuilder.AppendLine(
- $" public static {mapperName.Identifier.Text} {mapperName.Identifier.Text} {{ get; }} = new();"
- );
+ foreach (var mapper in mapperClasses)
+ {
+ var semanticModel = compilation.GetSemanticModel(mapper.SyntaxTree);
+
+ if (semanticModel.GetDeclaredSymbol(mapper) is not ITypeSymbol mapperSymbol)
+ continue;
+
+ var mapperType = TypeNameResolver.GetFullyQualifiedTypeName(mapperSymbol);
+ sourceBuilder.AppendLine($" public static {mapperType} {mapper.Identifier.Text} {{ get; }} = new();");
+ }
sourceBuilder.AppendLine("}");
@@ -128,16 +123,4 @@ private static bool ImplementsResultErrorMapper(ITypeSymbol? classSymbol)
interfaceSymbol.Name.StartsWith(ResultErrorMapperInterface)
);
}
-
- ///
- /// Retrieves the namespace of a class declaration.
- ///
- /// The class declaration syntax node.
- /// The semantic model for the syntax tree.
- /// The namespace of the class, or an empty string if the namespace cannot be determined.
- private static string GetNamespace(ClassDeclarationSyntax classDeclaration, SemanticModel semanticModel)
- {
- var symbol = semanticModel.GetDeclaredSymbol(classDeclaration);
- return symbol?.ContainingNamespace?.ToString() ?? string.Empty;
- }
}
diff --git a/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs b/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs
new file mode 100644
index 0000000..28c6e80
--- /dev/null
+++ b/CSharpFunctionalExtensions.HttpResults.Generators/Utils/TypeNameResolver.cs
@@ -0,0 +1,17 @@
+using Microsoft.CodeAnalysis;
+
+namespace CSharpFunctionalExtensions.HttpResults.Generators.Utils;
+
+internal static class TypeNameResolver
+{
+ private static readonly SymbolDisplayFormat FullyQualifiedWithNullables =
+ SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions(
+ SymbolDisplayFormat.FullyQualifiedFormat.MiscellaneousOptions
+ | SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
+ );
+
+ public static string GetFullyQualifiedTypeName(ITypeSymbol typeSymbol)
+ {
+ return typeSymbol.ToDisplayString(FullyQualifiedWithNullables);
+ }
+}