diff --git a/FactoryGenerator/FactoryGenerator.cs b/FactoryGenerator/FactoryGenerator.cs index bdcfe98..6ab7f2d 100644 --- a/FactoryGenerator/FactoryGenerator.cs +++ b/FactoryGenerator/FactoryGenerator.cs @@ -154,6 +154,7 @@ private static IEnumerable GenerateCode(ImmutableArray da using System; using System.Linq; using System.Collections.Generic; +using System.Collections.Immutable; using FactoryGenerator; using System.CodeDom.Compiler; namespace {compilation.Assembly.Name}.Generated; @@ -379,22 +380,11 @@ public bool GetBoolean(string key) foreach (var parameter in constructorParameters.ToArray()) { - if (!parameter.IsEnumerable) continue; - if (parameter.EnumerableElementFullName is null) continue; + if (!parameter.IsCollection) continue; + if (parameter.CollectionElementFullName is null) continue; var name = parameter.Name; - log.Log(LogLevel.Debug, $"Creating Array: {name} of type {parameter.EnumerableElementFullName}[]"); - MakeArray(arrayDeclarations, name, parameter.EnumerableElementFullName, parameter.EnumerableElementMemberName!, interfaceInjectors); - constructorParameters.Remove(parameter); - localizedParameters.Add(parameter); - } - - foreach (var parameter in constructorParameters.ToArray()) - { - if (!parameter.IsArrayType) continue; - if (parameter.ArrayElementFullName is null) continue; - var name = parameter.Name; - log.Log(LogLevel.Debug, $"Creating Array: {name} of type {parameter.ArrayElementFullName}[]"); - MakeArray(arrayDeclarations, name, parameter.ArrayElementFullName, parameter.ArrayElementMemberName!, interfaceInjectors); + log.Log(LogLevel.Debug, $"Creating Collection: {name} of element type {parameter.CollectionElementFullName}"); + MakeArray(arrayDeclarations, name, parameter.CollectionElementFullName, parameter.CollectionElementMemberName!, interfaceInjectors); constructorParameters.Remove(parameter); localizedParameters.Add(parameter); } @@ -441,11 +431,15 @@ public bool GetBoolean(string key) .Select(a => $"this.{a[1]} = Base.Resolve<{a[0]}>();")); var interfacePairs = interfaceInjectors.Keys.Select(k => (TypeName: k, MemberName: interfaceMemberNames[k])).ToList(); - var localizedPairs = localizedParameters.Select(p => (TypeName: p.TypeFullName, ParamName: p.Name)).ToList(); + // ReadOnlySpan is a ref struct and cannot be placed in the lookup dictionary + var localizedForDict = localizedParameters.Where(p => p.CollectionKind != CollectionKind.ReadOnlySpan).ToList(); + var localizedPairs = localizedForDict + .Select(p => (TypeName: p.TypeFullName, Expression: CollectionDictExpression(p.CollectionKind, p.Name))) + .ToList(); var requestedPairs = requestedUsages.Select(u => (TypeName: u.FullName, MemberName: u.MemberName)).ToList(); - var constructorPairs = constructorParameters.Select(p => (TypeName: p.TypeFullName, ParamName: p.Name)).ToList(); + var constructorPairs = constructorParameters.Select(p => (TypeName: p.TypeFullName, Expression: p.Name)).ToList(); - var dictSize = interfaceInjectors.Count + localizedParameters.Count + requestedUsages.Count + constructorParameters.Count; + var dictSize = interfaceInjectors.Count + localizedForDict.Count + requestedUsages.Count + constructorParameters.Count; yield return Constructor(usingStatements, constructorFields, constructor, constructorAssignments, dictSize, interfacePairs, localizedPairs, requestedPairs, constructorPairs, @@ -607,19 +601,17 @@ private static void CheckForCycles(ImmutableArray dataInjections) foreach (var parameter in ctor.Parameters) { string? depName; - if (parameter.IsEnumerable) + if (parameter.IsCollection) { - if (parameter.EnumerableElementFullName is null) continue; - depName = parameter.EnumerableElementFullName; - } - else if (parameter.IsArrayType) - { - if (parameter.ArrayElementFullName is null) continue; - depName = parameter.ArrayElementFullName; + if (parameter.CollectionElementFullName is null) continue; + depName = parameter.CollectionElementFullName; } else { - depName = parameter.TypeFullName; + // Strip ? so nullable params resolve to their underlying type in the cycle graph + depName = parameter.IsNullable + ? parameter.TypeFullName.TrimEnd('?') + : parameter.TypeFullName; } node.Add(depName); @@ -638,8 +630,8 @@ private static void CheckForCycles(ImmutableArray dataInjections) } private static string Constructor(string usingStatements, string constructorFields, string constructor, string constructorAssignments, int dictSize, - IEnumerable<(string TypeName, string MemberName)> interfaceTypePairs, IEnumerable<(string TypeName, string ParamName)> localizedParamPairs, - IEnumerable<(string TypeName, string MemberName)> requestedPairs, IEnumerable<(string TypeName, string ParamName)> constructorParamPairs, + IEnumerable<(string TypeName, string MemberName)> interfaceTypePairs, IEnumerable<(string TypeName, string Expression)> localizedParamPairs, + IEnumerable<(string TypeName, string MemberName)> requestedPairs, IEnumerable<(string TypeName, string Expression)> constructorParamPairs, bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null, string? fromConstructor = null, string? resolvingConstructorAssignments = null, bool addMergingConstructor = true, List booleans = null!) { @@ -787,11 +779,11 @@ private static string MakeDictionaryFromTypes(IEnumerable<(string TypeName, stri return builder.ToString(); } - private static string MakeDictionaryFromParams(IEnumerable<(string TypeName, string ParamName)> pairs) + private static string MakeDictionaryFromParams(IEnumerable<(string TypeName, string Expression)> pairs) { var builder = new StringBuilder(); - foreach (var (typeName, paramName) in pairs) - builder.AppendLine($"\t\t\t{{ typeof({typeName}), () => {paramName} }},"); + foreach (var (typeName, expression) in pairs) + builder.AppendLine($"\t\t\t{{ typeof({typeName}), () => {expression} }},"); return builder.ToString(); } @@ -821,34 +813,69 @@ private static string CreationCall(InjectionData injection, ImmutableArray? lambdaNullableDefaults = null; + if (lambda.IsMethod && lambda.MethodParameters.Length > 0) + { + foreach (var p in lambda.MethodParameters) + { + if (!p.IsNullable) continue; + var baseType = p.TypeFullName.TrimEnd('?'); + if (availableInterfaceFullNames.Contains(baseType)) continue; + lambdaNullableDefaults ??= new HashSet(); + lambdaNullableDefaults.Add(p); + } + } + if (lambda.IsMethod) - return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}{MakeMethodCall(lambda.MethodParameters, null)}"; + return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}{MakeMethodCall(lambda.MethodParameters, null, lambdaNullableDefaults)}"; else return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}"; } HashSet? missing = null; - var ctor = GetBestConstructor(injection, availableInterfaceFullNames, ref missing); + HashSet? nullableDefaults = null; + var ctor = GetBestConstructor(injection, availableInterfaceFullNames, ref missing, ref nullableDefaults); if (ctor is null) throw new Exception($"No Construction method for {injection.TypeFullName}. Lambda was null."); - return $"new {injection.TypeFullName}{MakeConstructorCall(ctor, missing)}"; + return $"new {injection.TypeFullName}{MakeConstructorCall(ctor, missing, nullableDefaults)}"; } private static ConstructorData? GetBestConstructor(InjectionData injection, - ImmutableArray availableInterfaceFullNames, ref HashSet? missing) + ImmutableArray availableInterfaceFullNames, ref HashSet? missing, + ref HashSet? nullableDefaults) { missing = null; + nullableDefaults = null; ConstructorData? chosen = null; foreach (var ctor in injection.Constructors) { var valid = true; var localMissing = new HashSet(); + var localNullableDefaults = new HashSet(); foreach (var parameter in ctor.Parameters) { - if (availableInterfaceFullNames.Contains(parameter.TypeFullName)) continue; + // For nullable params, check availability of the underlying non-nullable type + var typeLookup = parameter.IsNullable + ? parameter.TypeFullName.TrimEnd('?') + : parameter.TypeFullName; + if (availableInterfaceFullNames.Contains(typeLookup)) continue; if (parameter.HasExplicitDefault) continue; if (parameter.IsParams) continue; + // Collection params (IEnumerable, T[], List, etc.) are always satisfiable + // via MakeArray – add to missing for factory generation but keep constructor valid + if (parameter.IsCollection) + { + localMissing.Add(parameter); + continue; + } + // Nullable reference/value params that aren't registered default to null + if (parameter.IsNullable) + { + localNullableDefaults.Add(parameter); + continue; + } valid = false; localMissing.Add(parameter); } @@ -856,13 +883,15 @@ private static string CreationCall(InjectionData injection, ImmutableArray 0 ? localMissing : null; + nullableDefaults = localNullableDefaults.Count > 0 ? localNullableDefaults : null; break; } if ((missing?.Count ?? int.MaxValue) <= localMissing.Count) continue; chosen = ctor; missing = localMissing; + nullableDefaults = localNullableDefaults.Count > 0 ? localNullableDefaults : null; } return chosen; } @@ -871,18 +900,24 @@ private static IEnumerable GetBestConstructorMissing(InjectionDat ImmutableArray availableInterfaceFullNames) { HashSet? missing = null; - GetBestConstructor(injection, availableInterfaceFullNames, ref missing); + HashSet? nullableDefaults = null; + GetBestConstructor(injection, availableInterfaceFullNames, ref missing, ref nullableDefaults); return missing ?? Enumerable.Empty(); } - private static string MakeConstructorCall(ConstructorData ctor, HashSet? missing) + private static string MakeConstructorCall(ConstructorData ctor, HashSet? missing, HashSet? nullableDefaults) { var args = new List(); foreach (var parameter in ctor.Parameters) { + if (nullableDefaults?.Contains(parameter) == true) + { + args.Add("null"); + continue; + } if (missing?.Contains(parameter) == true) { - args.Add(parameter.Name); + args.Add(CollectionConstructorArg(parameter)); continue; } args.Add(parameter.TypeMemberName + "()"); @@ -890,19 +925,52 @@ private static string MakeConstructorCall(ConstructorData ctor, HashSet parameters, HashSet? missing) + private static string MakeMethodCall(ImmutableArray parameters, HashSet? missing, HashSet? nullableDefaults = null) { var args = new List(); foreach (var parameter in parameters) { + if (nullableDefaults?.Contains(parameter) == true) + { + args.Add("null"); + continue; + } if (missing?.Contains(parameter) == true) { - args.Add(parameter.Name); + args.Add(CollectionConstructorArg(parameter)); continue; } args.Add(parameter.TypeMemberName + "()"); } return $"({string.Join(", ", args)})"; } + + /// + /// Returns the expression to use when passing a collection (or plain-missing) parameter + /// in a generated constructor call. Collection params are converted from the cached + /// IEnumerable<T> factory to the exact type requested. + /// + private static string CollectionConstructorArg(ParameterData parameter) => + parameter.CollectionKind switch + { + CollectionKind.Array => $"{parameter.Name}.ToArray()", + CollectionKind.List => $"{parameter.Name}.ToList()", + CollectionKind.ImmutableArray => $"ImmutableArray.CreateRange({parameter.Name})", + CollectionKind.ReadOnlySpan => $"new global::System.ReadOnlySpan<{parameter.CollectionElementFullName}>({parameter.Name}.ToArray())", + _ => parameter.Name, // Enumerable or plain missing → use name directly + }; + + /// + /// Returns the expression used in the Func<object> lambda inside the lookup dictionary + /// for a localized (collection) parameter. + /// + private static string CollectionDictExpression(CollectionKind kind, string factoryName) => + kind switch + { + CollectionKind.Array => $"{factoryName}.ToArray()", + CollectionKind.List => $"{factoryName}.ToList()", + CollectionKind.ImmutableArray => $"ImmutableArray.CreateRange({factoryName})", + _ => factoryName, // Enumerable → direct + }; } } \ No newline at end of file diff --git a/FactoryGenerator/Injection.cs b/FactoryGenerator/Injection.cs index 241ebbf..890bf13 100644 --- a/FactoryGenerator/Injection.cs +++ b/FactoryGenerator/Injection.cs @@ -136,29 +136,30 @@ private static ParameterData ExtractParameter(IParameterSymbol parameter) { var typeFullName = parameter.Type.ToString()!; var typeMemberName = SymbolUtility.MemberName(parameter.Type).Replace("()", ""); + var isNullable = parameter.Type.NullableAnnotation == NullableAnnotation.Annotated; - var isEnumerable = SymbolUtility.IsEnumerable(parameter.Type); - string? enumElemFull = null, enumElemMember = null; - if (isEnumerable && parameter.Type is INamedTypeSymbol namedEnum && namedEnum.TypeArguments.Length == 1) + var collectionKind = SymbolUtility.GetCollectionKind(parameter.Type); + string? elemFull = null, elemMember = null; + if (collectionKind != CollectionKind.None) { - var elem = namedEnum.TypeArguments[0]; - enumElemFull = elem.ToString()!; - enumElemMember = SymbolUtility.MemberName(elem).Replace("()", ""); - } + ITypeSymbol? elemType = null; + if (parameter.Type is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + elemType = namedType.TypeArguments[0]; + else if (parameter.Type is IArrayTypeSymbol arrType) + elemType = arrType.ElementType; - var isArray = parameter.Type is IArrayTypeSymbol; - string? arrElemFull = null, arrElemMember = null; - if (isArray && parameter.Type is IArrayTypeSymbol arrType && arrType.ElementType is INamedTypeSymbol arrElem) - { - arrElemFull = arrElem.ToString()!; - arrElemMember = SymbolUtility.MemberName(arrElem).Replace("()", ""); + if (elemType is not null) + { + elemFull = elemType.ToString()!; + elemMember = SymbolUtility.MemberName(elemType).Replace("()", ""); + } } return new ParameterData( typeFullName, typeMemberName, parameter.HasExplicitDefaultValue, parameter.IsParams, parameter.Name, - isEnumerable, enumElemFull, enumElemMember, - isArray, arrElemFull, arrElemMember); + collectionKind, elemFull, elemMember, + isNullable); } private static BooleanInjection? HandleBoolean(AttributeData attributeData) diff --git a/FactoryGenerator/InjectionData.cs b/FactoryGenerator/InjectionData.cs index 7b77c6f..2ad2cc1 100644 --- a/FactoryGenerator/InjectionData.cs +++ b/FactoryGenerator/InjectionData.cs @@ -5,6 +5,8 @@ namespace FactoryGenerator { + public enum CollectionKind { None, Enumerable, Array, List, ImmutableArray, ReadOnlySpan } + public sealed class InjectionData : IEquatable { public string TypeFullName { get; } @@ -87,34 +89,35 @@ public bool Equals(ConstructorData? other) public sealed class ParameterData : IEquatable { public string TypeFullName { get; } - public string TypeMemberName { get; } // MemberName(param.Type) without "()" + public string TypeMemberName { get; } // MemberName(param.Type) without "()" public bool HasExplicitDefault { get; } public bool IsParams { get; } - public string Name { get; } // parameter.Name - public bool IsEnumerable { get; } - public string? EnumerableElementFullName { get; } - public string? EnumerableElementMemberName { get; } // without "()" - public bool IsArrayType { get; } - public string? ArrayElementFullName { get; } - public string? ArrayElementMemberName { get; } // without "()" + public string Name { get; } // parameter.Name + public CollectionKind CollectionKind { get; } + public string? CollectionElementFullName { get; } + public string? CollectionElementMemberName { get; } // without "()" + public bool IsNullable { get; } + + // Convenience + public bool IsCollection => CollectionKind != CollectionKind.None; + public bool IsEnumerable => CollectionKind == CollectionKind.Enumerable; + public bool IsArrayType => CollectionKind == CollectionKind.Array; public ParameterData( string typeFullName, string typeMemberName, bool hasExplicitDefault, bool isParams, string name, - bool isEnumerable, string? enumerableElementFullName, string? enumerableElementMemberName, - bool isArrayType, string? arrayElementFullName, string? arrayElementMemberName) + CollectionKind collectionKind, string? collectionElementFullName, string? collectionElementMemberName, + bool isNullable) { TypeFullName = typeFullName; TypeMemberName = typeMemberName; HasExplicitDefault = hasExplicitDefault; IsParams = isParams; Name = name; - IsEnumerable = isEnumerable; - EnumerableElementFullName = enumerableElementFullName; - EnumerableElementMemberName = enumerableElementMemberName; - IsArrayType = isArrayType; - ArrayElementFullName = arrayElementFullName; - ArrayElementMemberName = arrayElementMemberName; + CollectionKind = collectionKind; + CollectionElementFullName = collectionElementFullName; + CollectionElementMemberName = collectionElementMemberName; + IsNullable = isNullable; } public bool Equals(ParameterData? other) @@ -126,12 +129,10 @@ public bool Equals(ParameterData? other) && HasExplicitDefault == other.HasExplicitDefault && IsParams == other.IsParams && Name == other.Name - && IsEnumerable == other.IsEnumerable - && EnumerableElementFullName == other.EnumerableElementFullName - && EnumerableElementMemberName == other.EnumerableElementMemberName - && IsArrayType == other.IsArrayType - && ArrayElementFullName == other.ArrayElementFullName - && ArrayElementMemberName == other.ArrayElementMemberName; + && CollectionKind == other.CollectionKind + && CollectionElementFullName == other.CollectionElementFullName + && CollectionElementMemberName == other.CollectionElementMemberName + && IsNullable == other.IsNullable; } public override bool Equals(object? obj) => obj is ParameterData other && Equals(other); diff --git a/FactoryGenerator/SymbolUtility.cs b/FactoryGenerator/SymbolUtility.cs index 3784d67..f28762d 100644 --- a/FactoryGenerator/SymbolUtility.cs +++ b/FactoryGenerator/SymbolUtility.cs @@ -51,16 +51,33 @@ public static IEnumerable GetAllTypes(INamedTypeSymbol root) } } - internal static bool IsEnumerable(ITypeSymbol symbol) + internal static bool IsEnumerable(ITypeSymbol symbol) => + GetCollectionKind(symbol) == CollectionKind.Enumerable; + + internal static CollectionKind GetCollectionKind(ITypeSymbol symbol) { - if (symbol.SpecialType == SpecialType.System_Collections_IEnumerable) return true; + if (symbol is IArrayTypeSymbol) return CollectionKind.Array; + + if (symbol.SpecialType == SpecialType.System_Collections_IEnumerable) + return CollectionKind.Enumerable; + if (symbol is INamedTypeSymbol named) { var fullName = named.ConstructedFrom.ToDisplayString(); - if (fullName == "System.Collections.Generic.IEnumerable") return true; + switch (fullName) + { + case "System.Collections.Generic.IEnumerable": return CollectionKind.Enumerable; + case "System.Collections.Generic.List": return CollectionKind.List; + case "System.Collections.Immutable.ImmutableArray": return CollectionKind.ImmutableArray; + case "System.ReadOnlySpan": return CollectionKind.ReadOnlySpan; + } } - return symbol.Name == "IEnumerable" && symbol.ContainingNamespace?.ToDisplayString() is - "System.Collections.Generic" or "System.Collections"; + + if (symbol.Name == "IEnumerable" && symbol.ContainingNamespace?.ToDisplayString() is + "System.Collections.Generic" or "System.Collections") + return CollectionKind.Enumerable; + + return CollectionKind.None; } public static string MemberName(ISymbol? type) diff --git a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs index c877477..5cceb2e 100644 --- a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs +++ b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs @@ -259,6 +259,46 @@ public void HierarchicalContainersPropgatesBooleansUnknownToIt() newContainer.GetBoolean("B").ShouldBe(true); newContainer.GetBoolean("C").ShouldBe(false); } + + // ── Nullable parameter tests ────────────────────────────────────────────── + + [Fact] + public void NullableUnregisteredParameterDefaultsToNull() + { + m_container.Resolve().Optional.ShouldBeNull(); + } + + [Fact] + public void NullableRegisteredParameterIsResolved() + { + m_container.Resolve().Optional.ShouldBeOfType(); + } + + // ── Collection constructor parameter tests ──────────────────────────────── + + [Fact] + public void ArrayConstructorParameterIsResolved() + { + m_container.Resolve().Arrays.Length.ShouldBe(3); + } + + [Fact] + public void ListConstructorParameterIsResolved() + { + m_container.Resolve().Arrays.Count.ShouldBe(3); + } + + [Fact] + public void ImmutableArrayConstructorParameterIsResolved() + { + m_container.Resolve().Arrays.Length.ShouldBe(3); + } + + [Fact] + public void ReadOnlySpanConstructorParameterIsResolved() + { + m_container.Resolve().Count.ShouldBe(3); + } private class DummyContainer : IContainer { public const string DummyText = "I am a bit of text"; diff --git a/Tests/TestData/Inherited/Types.cs b/Tests/TestData/Inherited/Types.cs index 51e5442..70b1483 100644 --- a/Tests/TestData/Inherited/Types.cs +++ b/Tests/TestData/Inherited/Types.cs @@ -1,4 +1,6 @@ using FactoryGenerator.Attributes; +using System.Collections.Generic; +using System.Collections.Immutable; namespace Inherited; @@ -187,4 +189,52 @@ public void Dispose() { WasDisposed = true; } +} + +// ── Nullable parameter tests ───────────────────────────────────────────────── + +/// Interface with no [Inject] implementation — intentionally unregistered. +public interface INullableOptional; + +[Inject, Self] +public class NullableConsumer(INullableOptional? optional) +{ + public INullableOptional? Optional { get; } = optional; +} + +public interface INullablePresent; + +[Inject] +public class NullablePresent : INullablePresent; + +[Inject, Self] +public class NullablePresentConsumer(INullablePresent? optional) +{ + public INullablePresent? Optional { get; } = optional; +} + +// ── Additional collection type tests ───────────────────────────────────────── + +[Inject, Self] +public class ArrayParameterConsumer(IArray[] arrays) +{ + public IArray[] Arrays { get; } = arrays; +} + +[Inject, Self] +public class ListConsumer(List arrays) +{ + public List Arrays { get; } = arrays; +} + +[Inject, Self] +public class ImmutableArrayConsumer(ImmutableArray arrays) +{ + public ImmutableArray Arrays { get; } = arrays; +} + +[Inject, Self] +public class ReadOnlySpanConsumer(ReadOnlySpan arrays) +{ + public int Count { get; } = arrays.Length; } \ No newline at end of file