Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 111 additions & 43 deletions FactoryGenerator/FactoryGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ private static IEnumerable<string> GenerateCode(ImmutableArray<InjectionData> da
using System;
using System.Linq;
using System.Collections.Generic;
using System.Collections.Immutable;
using FactoryGenerator;
using System.CodeDom.Compiler;
namespace {compilation.Assembly.Name}.Generated;
Expand Down Expand Up @@ -379,22 +380,11 @@ public bool GetBoolean(string key)

foreach (var parameter in constructorParameters.ToArray())
{
if (!parameter.IsEnumerable) continue;
if (parameter.EnumerableElementFullName is null) continue;
if (!parameter.IsCollection) continue;
if (parameter.CollectionElementFullName is null) continue;
var name = parameter.Name;
log.Log(LogLevel.Debug, $"Creating Array: {name} of type {parameter.EnumerableElementFullName}[]");
MakeArray(arrayDeclarations, name, parameter.EnumerableElementFullName, parameter.EnumerableElementMemberName!, interfaceInjectors);
constructorParameters.Remove(parameter);
localizedParameters.Add(parameter);
}

foreach (var parameter in constructorParameters.ToArray())
{
if (!parameter.IsArrayType) continue;
if (parameter.ArrayElementFullName is null) continue;
var name = parameter.Name;
log.Log(LogLevel.Debug, $"Creating Array: {name} of type {parameter.ArrayElementFullName}[]");
MakeArray(arrayDeclarations, name, parameter.ArrayElementFullName, parameter.ArrayElementMemberName!, interfaceInjectors);
log.Log(LogLevel.Debug, $"Creating Collection: {name} of element type {parameter.CollectionElementFullName}");
MakeArray(arrayDeclarations, name, parameter.CollectionElementFullName, parameter.CollectionElementMemberName!, interfaceInjectors);
constructorParameters.Remove(parameter);
localizedParameters.Add(parameter);
}
Expand Down Expand Up @@ -441,11 +431,15 @@ public bool GetBoolean(string key)
.Select(a => $"this.{a[1]} = Base.Resolve<{a[0]}>();"));

var interfacePairs = interfaceInjectors.Keys.Select(k => (TypeName: k, MemberName: interfaceMemberNames[k])).ToList();
var localizedPairs = localizedParameters.Select(p => (TypeName: p.TypeFullName, ParamName: p.Name)).ToList();
// ReadOnlySpan is a ref struct and cannot be placed in the lookup dictionary
var localizedForDict = localizedParameters.Where(p => p.CollectionKind != CollectionKind.ReadOnlySpan).ToList();
var localizedPairs = localizedForDict
.Select(p => (TypeName: p.TypeFullName, Expression: CollectionDictExpression(p.CollectionKind, p.Name)))
.ToList();
var requestedPairs = requestedUsages.Select(u => (TypeName: u.FullName, MemberName: u.MemberName)).ToList();
var constructorPairs = constructorParameters.Select(p => (TypeName: p.TypeFullName, ParamName: p.Name)).ToList();
var constructorPairs = constructorParameters.Select(p => (TypeName: p.TypeFullName, Expression: p.Name)).ToList();

var dictSize = interfaceInjectors.Count + localizedParameters.Count + requestedUsages.Count + constructorParameters.Count;
var dictSize = interfaceInjectors.Count + localizedForDict.Count + requestedUsages.Count + constructorParameters.Count;
yield return Constructor(usingStatements, constructorFields,
constructor, constructorAssignments,
dictSize, interfacePairs, localizedPairs, requestedPairs, constructorPairs,
Expand Down Expand Up @@ -607,19 +601,17 @@ private static void CheckForCycles(ImmutableArray<InjectionData> dataInjections)
foreach (var parameter in ctor.Parameters)
{
string? depName;
if (parameter.IsEnumerable)
if (parameter.IsCollection)
{
if (parameter.EnumerableElementFullName is null) continue;
depName = parameter.EnumerableElementFullName;
}
else if (parameter.IsArrayType)
{
if (parameter.ArrayElementFullName is null) continue;
depName = parameter.ArrayElementFullName;
if (parameter.CollectionElementFullName is null) continue;
depName = parameter.CollectionElementFullName;
}
else
{
depName = parameter.TypeFullName;
// Strip ? so nullable params resolve to their underlying type in the cycle graph
depName = parameter.IsNullable
? parameter.TypeFullName.TrimEnd('?')
: parameter.TypeFullName;
}

node.Add(depName);
Expand All @@ -638,8 +630,8 @@ private static void CheckForCycles(ImmutableArray<InjectionData> dataInjections)
}

private static string Constructor(string usingStatements, string constructorFields, string constructor, string constructorAssignments, int dictSize,
IEnumerable<(string TypeName, string MemberName)> interfaceTypePairs, IEnumerable<(string TypeName, string ParamName)> localizedParamPairs,
IEnumerable<(string TypeName, string MemberName)> requestedPairs, IEnumerable<(string TypeName, string ParamName)> constructorParamPairs,
IEnumerable<(string TypeName, string MemberName)> interfaceTypePairs, IEnumerable<(string TypeName, string Expression)> localizedParamPairs,
IEnumerable<(string TypeName, string MemberName)> requestedPairs, IEnumerable<(string TypeName, string Expression)> constructorParamPairs,
bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null,
string? fromConstructor = null, string? resolvingConstructorAssignments = null, bool addMergingConstructor = true, List<string> booleans = null!)
{
Expand Down Expand Up @@ -787,11 +779,11 @@ private static string MakeDictionaryFromTypes(IEnumerable<(string TypeName, stri
return builder.ToString();
}

private static string MakeDictionaryFromParams(IEnumerable<(string TypeName, string ParamName)> pairs)
private static string MakeDictionaryFromParams(IEnumerable<(string TypeName, string Expression)> pairs)
{
var builder = new StringBuilder();
foreach (var (typeName, paramName) in pairs)
builder.AppendLine($"\t\t\t{{ typeof({typeName}), () => {paramName} }},");
foreach (var (typeName, expression) in pairs)
builder.AppendLine($"\t\t\t{{ typeof({typeName}), () => {expression} }},");
return builder.ToString();
}

Expand Down Expand Up @@ -821,48 +813,85 @@ private static string CreationCall(InjectionData injection, ImmutableArray<strin
throw new Exception(
$"Could not find any [Inject]ed implementations of {lambda.ContainingTypeFullName} to use as the source for the injection of {lambda.ContainingTypeFullName}.{lambda.MemberName}. Please provide at least one injection of the type {lambda.ContainingTypeFullName}.");

// Compute nullable defaults for lambda method parameters
HashSet<ParameterData>? lambdaNullableDefaults = null;
if (lambda.IsMethod && lambda.MethodParameters.Length > 0)
{
foreach (var p in lambda.MethodParameters)
{
if (!p.IsNullable) continue;
var baseType = p.TypeFullName.TrimEnd('?');
if (availableInterfaceFullNames.Contains(baseType)) continue;
lambdaNullableDefaults ??= new HashSet<ParameterData>();
lambdaNullableDefaults.Add(p);
}
}

if (lambda.IsMethod)
return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}{MakeMethodCall(lambda.MethodParameters, null)}";
return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}{MakeMethodCall(lambda.MethodParameters, null, lambdaNullableDefaults)}";
else
return $"{lambda.ContainingTypeMemberName}.{lambda.MemberName}";
}

HashSet<ParameterData>? missing = null;
var ctor = GetBestConstructor(injection, availableInterfaceFullNames, ref missing);
HashSet<ParameterData>? nullableDefaults = null;
var ctor = GetBestConstructor(injection, availableInterfaceFullNames, ref missing, ref nullableDefaults);
if (ctor is null)
throw new Exception($"No Construction method for {injection.TypeFullName}. Lambda was null.");

return $"new {injection.TypeFullName}{MakeConstructorCall(ctor, missing)}";
return $"new {injection.TypeFullName}{MakeConstructorCall(ctor, missing, nullableDefaults)}";
}

private static ConstructorData? GetBestConstructor(InjectionData injection,
ImmutableArray<string> availableInterfaceFullNames, ref HashSet<ParameterData>? missing)
ImmutableArray<string> availableInterfaceFullNames, ref HashSet<ParameterData>? missing,
ref HashSet<ParameterData>? nullableDefaults)
{
missing = null;
nullableDefaults = null;
ConstructorData? chosen = null;
foreach (var ctor in injection.Constructors)
{
var valid = true;
var localMissing = new HashSet<ParameterData>();
var localNullableDefaults = new HashSet<ParameterData>();
foreach (var parameter in ctor.Parameters)
{
if (availableInterfaceFullNames.Contains(parameter.TypeFullName)) continue;
// For nullable params, check availability of the underlying non-nullable type
var typeLookup = parameter.IsNullable
? parameter.TypeFullName.TrimEnd('?')
: parameter.TypeFullName;
if (availableInterfaceFullNames.Contains(typeLookup)) continue;
if (parameter.HasExplicitDefault) continue;
if (parameter.IsParams) continue;
// Collection params (IEnumerable<T>, T[], List<T>, etc.) are always satisfiable
// via MakeArray – add to missing for factory generation but keep constructor valid
if (parameter.IsCollection)
{
localMissing.Add(parameter);
continue;
}
// Nullable reference/value params that aren't registered default to null
if (parameter.IsNullable)
{
localNullableDefaults.Add(parameter);
continue;
}
valid = false;
localMissing.Add(parameter);
}

if (valid)
{
chosen = ctor;
missing = null;
missing = localMissing.Count > 0 ? localMissing : null;
nullableDefaults = localNullableDefaults.Count > 0 ? localNullableDefaults : null;
break;
}

if ((missing?.Count ?? int.MaxValue) <= localMissing.Count) continue;
chosen = ctor;
missing = localMissing;
nullableDefaults = localNullableDefaults.Count > 0 ? localNullableDefaults : null;
}
return chosen;
}
Expand All @@ -871,38 +900,77 @@ private static IEnumerable<ParameterData> GetBestConstructorMissing(InjectionDat
ImmutableArray<string> availableInterfaceFullNames)
{
HashSet<ParameterData>? missing = null;
GetBestConstructor(injection, availableInterfaceFullNames, ref missing);
HashSet<ParameterData>? nullableDefaults = null;
GetBestConstructor(injection, availableInterfaceFullNames, ref missing, ref nullableDefaults);
return missing ?? Enumerable.Empty<ParameterData>();
}

private static string MakeConstructorCall(ConstructorData ctor, HashSet<ParameterData>? missing)
private static string MakeConstructorCall(ConstructorData ctor, HashSet<ParameterData>? missing, HashSet<ParameterData>? nullableDefaults)
{
var args = new List<string>();
foreach (var parameter in ctor.Parameters)
{
if (nullableDefaults?.Contains(parameter) == true)
{
args.Add("null");
continue;
}
if (missing?.Contains(parameter) == true)
{
args.Add(parameter.Name);
args.Add(CollectionConstructorArg(parameter));
continue;
}
args.Add(parameter.TypeMemberName + "()");
}
return $"({string.Join(", ", args)})";
}

private static string MakeMethodCall(ImmutableArray<ParameterData> parameters, HashSet<ParameterData>? missing)
private static string MakeMethodCall(ImmutableArray<ParameterData> parameters, HashSet<ParameterData>? missing, HashSet<ParameterData>? nullableDefaults = null)
{
var args = new List<string>();
foreach (var parameter in parameters)
{
if (nullableDefaults?.Contains(parameter) == true)
{
args.Add("null");
continue;
}
if (missing?.Contains(parameter) == true)
{
args.Add(parameter.Name);
args.Add(CollectionConstructorArg(parameter));
continue;
}
args.Add(parameter.TypeMemberName + "()");
}
return $"({string.Join(", ", args)})";
}

/// <summary>
/// Returns the expression to use when passing a collection (or plain-missing) parameter
/// in a generated constructor call. Collection params are converted from the cached
/// IEnumerable&lt;T&gt; factory to the exact type requested.
/// </summary>
private static string CollectionConstructorArg(ParameterData parameter) =>
parameter.CollectionKind switch
{
CollectionKind.Array => $"{parameter.Name}.ToArray()",
CollectionKind.List => $"{parameter.Name}.ToList()",
CollectionKind.ImmutableArray => $"ImmutableArray.CreateRange({parameter.Name})",
CollectionKind.ReadOnlySpan => $"new global::System.ReadOnlySpan<{parameter.CollectionElementFullName}>({parameter.Name}.ToArray())",
_ => parameter.Name, // Enumerable or plain missing → use name directly
};

/// <summary>
/// Returns the expression used in the Func&lt;object&gt; lambda inside the lookup dictionary
/// for a localized (collection) parameter.
/// </summary>
private static string CollectionDictExpression(CollectionKind kind, string factoryName) =>
kind switch
{
CollectionKind.Array => $"{factoryName}.ToArray()",
CollectionKind.List => $"{factoryName}.ToList()",
CollectionKind.ImmutableArray => $"ImmutableArray.CreateRange({factoryName})",
_ => factoryName, // Enumerable → direct
};
}
}
31 changes: 16 additions & 15 deletions FactoryGenerator/Injection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,29 +136,30 @@ private static ParameterData ExtractParameter(IParameterSymbol parameter)
{
var typeFullName = parameter.Type.ToString()!;
var typeMemberName = SymbolUtility.MemberName(parameter.Type).Replace("()", "");
var isNullable = parameter.Type.NullableAnnotation == NullableAnnotation.Annotated;

var isEnumerable = SymbolUtility.IsEnumerable(parameter.Type);
string? enumElemFull = null, enumElemMember = null;
if (isEnumerable && parameter.Type is INamedTypeSymbol namedEnum && namedEnum.TypeArguments.Length == 1)
var collectionKind = SymbolUtility.GetCollectionKind(parameter.Type);
string? elemFull = null, elemMember = null;
if (collectionKind != CollectionKind.None)
{
var elem = namedEnum.TypeArguments[0];
enumElemFull = elem.ToString()!;
enumElemMember = SymbolUtility.MemberName(elem).Replace("()", "");
}
ITypeSymbol? elemType = null;
if (parameter.Type is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1)
elemType = namedType.TypeArguments[0];
else if (parameter.Type is IArrayTypeSymbol arrType)
elemType = arrType.ElementType;

var isArray = parameter.Type is IArrayTypeSymbol;
string? arrElemFull = null, arrElemMember = null;
if (isArray && parameter.Type is IArrayTypeSymbol arrType && arrType.ElementType is INamedTypeSymbol arrElem)
{
arrElemFull = arrElem.ToString()!;
arrElemMember = SymbolUtility.MemberName(arrElem).Replace("()", "");
if (elemType is not null)
{
elemFull = elemType.ToString()!;
elemMember = SymbolUtility.MemberName(elemType).Replace("()", "");
}
}

return new ParameterData(
typeFullName, typeMemberName,
parameter.HasExplicitDefaultValue, parameter.IsParams, parameter.Name,
isEnumerable, enumElemFull, enumElemMember,
isArray, arrElemFull, arrElemMember);
collectionKind, elemFull, elemMember,
isNullable);
}

private static BooleanInjection? HandleBoolean(AttributeData attributeData)
Expand Down
Loading
Loading