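Previously, our server-side programs registered IoC services and Options by scanning the DLLs in the application directory. They iterated over every class marked with a particular attribute and registered each one according to the attribute's parameters.
AOT does not support dynamic loading, and the app is published as a single file, so there are no DLLs left to scan. The old approach essentially stops working.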
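For contrast, the reflection-based approach described above can be sketched roughly as follows. This is a minimal sketch, not the original code: `ScanDemoAttribute`, `IScanDemoService` and `ReflectionScanner` are hypothetical names. It also scans an already-loaded assembly instead of loading DLLs from a directory. Both steps (loading DLLs and enumerating their types) happen at run time through reflection, and that is what an AOT, single-file publish takes away.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical stand-in for the project's RegistAttribute (illustration only).
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public sealed class ScanDemoAttribute : Attribute
{
    public ScanDemoAttribute(Type? forType = null) => ForType = forType;
    public Type? ForType { get; }
}

public interface IScanDemoService { }

[ScanDemo(typeof(IScanDemoService))]
public sealed class ScanDemoService : IScanDemoService { }

public static class ReflectionScanner
{
    // Enumerates every type in the given assemblies at run time and reads the attribute via reflection.
    // The old approach also loaded the DLLs from disk first (e.g. with Assembly.LoadFrom),
    // which Native AOT does not support.
    public static List<(Type Service, Type Implementation)> Scan(IEnumerable<Assembly> assemblies)
    {
        var result = new List<(Type Service, Type Implementation)>();
        foreach (var assembly in assemblies)
        {
            foreach (var type in assembly.GetTypes())
            {
                foreach (var attr in type.GetCustomAttributes<ScanDemoAttribute>(inherit: false))
                {
                    // Fall back to the class itself when no service type is given.
                    result.Add((attr.ForType ?? type, type));
                }
            }
        }
        return result;
    }
}
```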
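Here I use IIncrementalGenerator rather than ISourceGenerator. I only dug into one of the two, because the documentation for both is not very friendly and figuring out even one of them took a long time.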
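using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
[Generator(LanguageNames.CSharp)]
public class RegistServiceGenerator : IIncrementalGenerator
{
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        //Debugger.Launch(); // uncomment to attach a debugger to the generator
        var provider = context.CompilationProvider.Select((compilation, cts) =>
        {
            // Global namespace of the assembly being compiled...
            var ns1 = new List<INamespaceSymbol>() {
                compilation.Assembly.GlobalNamespace
            };
            // ...plus those of the referenced assemblies whose name contains "CNB.PubTools"
            var ns2 = compilation.SourceModule
                .ReferencedAssemblySymbols
                .Where(s => s.Name.Contains("CNB.PubTools"))
                .Select(s => s.GlobalNamespace);
            var ns = ns1.Concat(ns2);
            var symbols = ns.SelectMany(s => Helper.GetAllTypeSymbol(s));
            return Helper.Transform(symbols.ToArray());
        });
        context.RegisterSourceOutput(provider, (spc, datas) =>
        {
            // Emit the file even when there is nothing to register: Regist.RegistService(IServiceCollection)
            // is a partial method with an access modifier, so it must always have an implementation.
            // GetCode already handles an empty sequence.
            var code = Helper.GetCode(datas);
            spc.AddSource("Regist.g.cs", code);
        });
    }
}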
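The code above does the following:
1. Collects the GlobalNamespace of the current assembly and of each referenced assembly whose name contains CNB.PubTools.
2. Walks all type symbols under those namespaces.
3. Transforms those symbols into intermediate data (RegistDescriptor).
4. Generates code from the intermediate data in RegisterSourceOutput.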
Walking all namespaces recursively to collect every type symbol:
public static IEnumerable<INamedTypeSymbol> GetAllTypeSymbol(INamespaceSymbol namespaceSymbol)
{
    // Types declared directly in this namespace
    var typeMemberList = namespaceSymbol.GetTypeMembers();
    foreach (var typeSymbol in typeMemberList)
    {
        yield return typeSymbol;
    }
    // Then recurse into each child namespace
    foreach (var namespaceMember in namespaceSymbol.GetNamespaceMembers())
    {
        foreach (var typeSymbol in GetAllTypeSymbol(namespaceMember))
        {
            yield return typeSymbol;
        }
    }
}
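The same pre-order walk can also be written with an explicit stack. This avoids nesting one iterator per namespace level. Both versions visit only top-level types, because `GetTypeMembers()` on a namespace does not return nested types, so an attribute on a nested class would be missed. To keep the sketch runnable without Roslyn, `NsNode` is a hypothetical toy stand-in for `INamespaceSymbol`, and type names are plain strings.

```csharp
using System.Collections.Generic;

// Toy stand-in for INamespaceSymbol (hypothetical; Roslyn is not referenced here).
public sealed class NsNode
{
    public NsNode(string name) => Name = name;
    public string Name { get; }
    public List<string> Types { get; } = new();     // plays the role of GetTypeMembers()
    public List<NsNode> Children { get; } = new();  // plays the role of GetNamespaceMembers()
}

public static class NsWalker
{
    // Iterative pre-order traversal: types of a namespace first, then its child namespaces.
    public static IEnumerable<string> GetAllTypes(NsNode root)
    {
        var stack = new Stack<NsNode>();
        stack.Push(root);
        while (stack.Count > 0)
        {
            var ns = stack.Pop();
            foreach (var t in ns.Types)
                yield return t;
            // Push children in reverse so they are popped in declaration order,
            // matching the order of the recursive version.
            for (int i = ns.Children.Count - 1; i >= 0; i--)
                stack.Push(ns.Children[i]);
        }
    }
}
```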
Transforming the symbols into intermediate data:
private static readonly string REGISTATTRIBUTE_NAME = "CNB.PubTools.Common.Attributes.RegistAttribute";
...
public static IEnumerable<RegistDescriptor> Transform(params INamedTypeSymbol[] symbols)
{
    if (symbols?.Any() != true)
        return Enumerable.Empty<RegistDescriptor>();
    // Keep only classes that carry RegistAttribute or an attribute directly derived from it
    var typeSymbols = symbols
        .Select(s => new
        {
            ClassName = s.ToDisplayString(),
            Attributes = s.GetAttributes().Where(a => a.AttributeClass.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass.BaseType.ToDisplayString() == REGISTATTRIBUTE_NAME)
        })
        .Where(s => s.Attributes?.Any() == true);
    var descriptors = new List<RegistDescriptor>();
    foreach (var typeSymbol in typeSymbols)
    {
        var tmp = typeSymbol.Attributes.Select(a =>
        {
            var name = a.AttributeClass.Name;
            return name switch
            {
                // RegistAttribute<T>: the service type is the type argument T.
                // RegistAttribute: the service type is the optional second constructor argument,
                // falling back to the class itself.
                "RegistAttribute" => a.AttributeClass.IsGenericType
                    ? new RegistDescriptor()
                    {
                        Lifetime = (int)a.ConstructorArguments[0].Value,
                        ForType = a.AttributeClass.TypeArguments.First().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
                        TargetType = typeSymbol.ClassName,
                        Tag = a.AttributeClass.Name
                    }
                    : new RegistDescriptor()
                    {
                        Lifetime = (int)a.ConstructorArguments[0].Value,
                        ForType = ((ITypeSymbol)a.ConstructorArguments[1].Value)?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ?? typeSymbol.ClassName,
                        TargetType = typeSymbol.ClassName,
                        Tag = a.AttributeClass.Name
                    },
                "RegistInitializerAttribute" => new RegistDescriptor()
                {
                    TargetType = typeSymbol.ClassName,
                    Tag = a.AttributeClass.Name
                },
                "RegistViewAttribute" => new RegistDescriptor()
                {
                    TargetType = typeSymbol.ClassName,
                    Tag = a.AttributeClass.Name,
                    Name = typeSymbol.ClassName
                },
                // Unknown derived attributes are skipped: an empty descriptor would be emitted
                // as "typeof()", which does not compile.
                _ => null,
            };
        });
        // OfType<RegistDescriptor>() drops the null entries
        descriptors.AddRange(tmp.OfType<RegistDescriptor>());
    }
    return descriptors;
}
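The post does not show `RegistDescriptor` itself. Judging from the properties `Transform` assigns, a minimal shape might look like the following. This is an assumption, not the author's definition. Declaring it as a `record` gives compiler-generated value equality, so two descriptors built from identical data compare equal. The sketch uses plain `set` accessors on purpose: Roslyn generator projects usually target netstandard2.0, where `init` accessors need the well-known `IsExternalInit` polyfill. `DescriptorDemo` is a hypothetical helper.

```csharp
using System.Linq;

// Hypothetical shape of RegistDescriptor, inferred from the properties Transform() assigns.
public sealed record RegistDescriptor
{
    public int Lifetime { get; set; }      // (int)RegistMode
    public string? ForType { get; set; }   // service type, fully qualified
    public string? TargetType { get; set; }
    public string? Tag { get; set; }       // attribute class name
    public string? Name { get; set; }      // key for keyed registrations
}

public static class DescriptorDemo
{
    // Value equality lets Distinct() collapse descriptors with identical data.
    public static int CountDistinct(params RegistDescriptor[] items) => items.Distinct().Count();
}
```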
.Where(a => a.AttributeClass.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass.BaseType.ToDisplayString() == REGISTATTRIBUTE_NAME)
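This filter keeps an attribute if its class is exactly RegistAttribute, or if its direct base class is RegistAttribute. The second case covers RegistAttribute<T> and the derived attributes such as RegistInitializerAttribute. Deeper inheritance chains are not matched.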
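Because the filter stops at the direct base class, an attribute deriving from, say, RegistInitializerAttribute would be skipped silently. If deeper hierarchies are ever needed, the check can walk the whole base-class chain. The sketch below uses runtime reflection (`System.Type.BaseType`) on a hypothetical three-level attribute hierarchy so that it runs without Roslyn. In the generator, the same loop would follow `INamedTypeSymbol.BaseType` and compare `ToDisplayString()` at each step.

```csharp
using System;

// Hypothetical three-level attribute hierarchy for illustration.
public class BaseDemoAttribute : Attribute { }
public class MidDemoAttribute : BaseDemoAttribute { }
public class LeafDemoAttribute : MidDemoAttribute { }

public static class AttributeMatch
{
    // Mirrors the filter above: the class itself or its direct base only.
    public static bool MatchesDirect(Type type, string fullName) =>
        type.FullName == fullName || type.BaseType?.FullName == fullName;

    // Walks the entire base-class chain up to System.Object.
    public static bool MatchesAnyBase(Type? type, string fullName)
    {
        for (var t = type; t != null; t = t.BaseType)
        {
            if (t.FullName == fullName)
                return true;
        }
        return false;
    }
}
```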
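RegistAttribute comes in many variants, and parsing every one of them completely would cost too much effort and time. So I took a compromise, the switch expression above: it dispatches on the attribute's class name and produces different data for each.
Besides ConstructorArguments, attribute data also exposes NamedArguments. Those vary too much to handle conveniently here, so RegistAttribute itself has to be kept simple: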
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute : Attribute/*, INamed*/
{
    /// <summary>
    /// Registration mode (lifetime)
    /// </summary>
    public RegistMode Mode { get; }
    /// <summary>
    /// The service type to register the class as
    /// </summary>
    public Type? ForType { get; }
    public RegistAttribute(RegistMode mode, Type? forType = null)
    {
        this.Mode = mode;
        this.ForType = forType;
    }
}
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute<T> : RegistAttribute
{
    public RegistAttribute(RegistMode mode) : base(mode, typeof(T))
    {
    }
}
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistInitializerAttribute : RegistAttribute
{
    /// <summary>
    /// Registers the class as a singleton IInitializer
    /// </summary>
    public RegistInitializerAttribute() : base(RegistMode.Singleton, typeof(IInitializer))
    {
    }
}
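To see why the switch branches on `IsGenericType`, compare the argument lists that the different forms produce. The runnable sketch below uses the runtime counterpart of Roslyn's `AttributeData`, `System.Reflection.CustomAttributeData`, which also exposes `ConstructorArguments` and `NamedArguments`. All names are hypothetical mirrors of the attributes above, and `Key` is an invented settable property added only to show a named argument. The non-generic form carries the mode and the (possibly null) service type. The generic form carries only the mode, because the service type lives in the type argument. The post's `(int)` cast suggests the generator receives the enum as its underlying integer, so the sketch prints enums as numbers too. Generic attributes need C# 11 or later.

```csharp
using System;
using System.Linq;
using System.Reflection;

public enum DemoMode { Scoped, Singleton, PreRequest }   // hypothetical mirror of RegistMode

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class DemoRegistAttribute : Attribute
{
    public DemoRegistAttribute(DemoMode mode, Type? forType = null) { Mode = mode; ForType = forType; }
    public DemoMode Mode { get; }
    public Type? ForType { get; }
    public string? Key { get; set; }   // settable property: appears only as a named argument
}

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class DemoRegistAttribute<T> : DemoRegistAttribute
{
    public DemoRegistAttribute(DemoMode mode) : base(mode, typeof(T)) { }
}

public interface IClock { }

[DemoRegist(DemoMode.Scoped)]                               // registered as itself
[DemoRegist(DemoMode.Singleton, typeof(IClock), Key = "k")] // explicit service type + named argument
[DemoRegist<IClock>(DemoMode.PreRequest)]                   // service type comes from <T>
public sealed class SystemClock : IClock { }

public static class AttributeArgsDemo
{
    // One line per attribute: "AttributeName: constructor args | named args".
    public static string[] Describe(Type target) =>
        target.GetCustomAttributesData()
            .Where(d => typeof(DemoRegistAttribute).IsAssignableFrom(d.AttributeType))
            .Select(d =>
                $"{d.AttributeType.Name}: " +
                string.Join(",", d.ConstructorArguments.Select(Format)) +
                " | " +
                string.Join(",", d.NamedArguments.Select(n => $"{n.MemberName}={n.TypedValue.Value}")))
            .ToArray();

    private static string Format(CustomAttributeTypedArgument a)
    {
        if (a.Value is null) return "null";
        if (a.ArgumentType.IsEnum) return Convert.ToInt64(a.Value).ToString();  // enums as numbers
        return a.Value.ToString() ?? "";
    }
}
```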
Generating the code:
public static string GetCode(RegistDescriptor descriptor)
{
    // No explicit service type: dispatch on the attribute name (Tag) at run time
    if (string.IsNullOrWhiteSpace(descriptor.ForType))
        return $"RegistService(sc, typeof({descriptor.TargetType}), \"{descriptor.Tag}\", \"{descriptor.Name}\");";
    else
        return $"RegistService(sc, typeof({descriptor.ForType}), typeof({descriptor.TargetType}), {descriptor.Lifetime});";
}
public static string GetCode(IEnumerable<RegistDescriptor> ds)
{
    if (ds?.Any() != true)
        ds = Enumerable.Empty<RegistDescriptor>();
    var arr = ds.Select(GetCode).Distinct().OrderBy(s => s);
    // The separator re-indents every line after the first to match the {{str}} position below
    var str = string.Join("\r\n            ", arr);
    var code = $$"""
    using CNB.PubTools.Common;
    using Microsoft.Extensions.DependencyInjection;
    using System;
    using System.Diagnostics.CodeAnalysis;
    namespace CNB.PubTools
    {
        internal partial class Regist
        {
            public static partial void RegistService(IServiceCollection sc)
            {
                {{str}}
            }
        }
    }
    """;
    return code;
}
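GetCode relies on a C# 11 raw interpolated string (`$$"""`). With two dollar signs, only `{{...}}` opens an interpolation hole, so the single braces of the generated namespace and class are copied literally without escaping. The hypothetical `RawStringDemo` below shows this. It also shows that only the first line of an interpolated multi-line value picks up the template's indentation. The following lines have to carry their own, which is why the lines are joined with a separator that ends in spaces.

```csharp
public static class RawStringDemo
{
    // Only {{body}} is interpolated; the single braces are literal template text.
    // The closing """ sets the indentation that is stripped from every line.
    public static string Wrap(string body) => $$"""
        class Demo
        {
            {{body}}
        }
        """;
}
```

For example, `Wrap("int x;\nint y;")` puts `int y;` at column 0, while `Wrap("int x;\n    int y;")` lines both statements up.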
To go with the SourceGenerator above, the other half of the partial class has to be written by hand:
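using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions; // TryAddKeyedTransient (keyed services, .NET 8+)
using CNB.PubTools.Common; // assumed location of RegistMode, IInitializer and IView, as in the generated file
// Must match the namespace of the generated Regist.g.cs so that both parts form one class
namespace CNB.PubTools;
/// <summary>
///
/// </summary>
internal partial class Regist
{
    /// <summary>
    ///
    /// </summary>
    /// <param name="sc"></param>
    /// <param name="serviceType"></param>
    /// <param name="targetType"></param>
    /// <param name="mode"></param>
    private static void RegistService(IServiceCollection sc, Type serviceType, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, int mode)
    {
        // Lifetime-based registration, emitted for RegistAttribute and RegistAttribute<T>
        switch ((RegistMode)mode)
        {
            case RegistMode.Scoped:
                sc.AddScoped(serviceType, targetType);
                break;
            case RegistMode.Singleton:
                sc.AddSingleton(serviceType, targetType);
                break;
            case RegistMode.PreRequest:
                sc.AddTransient(serviceType, targetType);
                break;
        }
    }
    private static void RegistService(IServiceCollection sc, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, string tag, string? name)
    {
        // Tag-based registration, emitted for attributes without an explicit service type
        switch (tag)
        {
            case "RegistInitializerAttribute":
                sc.AddSingleton(typeof(IInitializer), targetType);
                break;
            case "RegistViewAttribute":
                sc.TryAddKeyedTransient(typeof(IView), name, targetType);
                break;
        }
    }
    // Implemented by the generated Regist.g.cs
    public static partial void RegistService(IServiceCollection sc);
}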
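The split works because of C# 9 extended partial methods. A partial method declared with an access modifier (like `public static partial void RegistService(IServiceCollection sc)`) must have an implementing part somewhere in the compilation, or the build fails with error CS8795. So the generator should emit its file even when there is nothing to register. The hypothetical `PartialDemo` below puts both halves in one file to show the pattern. The second part stands in for what a source generator would emit, and it can call the private helpers defined in the hand-written part.

```csharp
using System.Collections.Generic;

// "Hand-written" half: declares the partial method and owns the helpers.
internal static partial class PartialDemo
{
    // Has an access modifier, so an implementation is mandatory (otherwise CS8795).
    public static partial List<string> RegistAll();

    private static void Regist(List<string> log, string typeName) => log.Add(typeName);
}

// "Generated" half: stands in for the output of a source generator.
internal static partial class PartialDemo
{
    public static partial List<string> RegistAll()
    {
        var log = new List<string>();
        Regist(log, "Demo.Foo");
        Regist(log, "Demo.Bar");
        return log;
    }
}
```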
Then reference the SourceGenerator project from the main project:
<ProjectReference Include="..\CNB.PubTools.SourceGenerator\CNB.PubTools.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
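The generator is referenced as an analyzer rather than as an ordinary library. OutputItemType="Analyzer" makes the compiler load it as an analyzer/generator, and ReferenceOutputAssembly="false" keeps its output DLL out of the main project's references.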
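Finally, you can add the following inside a PropertyGroup in the project file:
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
This writes the generated source files to disk under the obj directory (in a generated subfolder by default).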
Alternatively, the files produced by the SourceGenerator can be found in the following location: