|
using System.Diagnostics; |
|
using System.Text; |
|
using Microsoft.CodeAnalysis; |
|
using Microsoft.CodeAnalysis.CSharp; |
|
using Microsoft.CodeAnalysis.CSharp.Syntax; |
|
|
|
namespace Terminal.Shell.Analyzers; |
|
|
|
/// <summary>
/// Incremental source generator that emits a <c>partial</c> declaration for every partial class
/// in the current assembly annotated with an attribute whose class derives from
/// <c>Terminal.Shell.ComponentAttribute</c>. The generated part carries one
/// <c>[ExportMetadata]</c> attribute per distinct constructor/named argument of those attributes.
/// </summary>
[Generator(LanguageNames.CSharp)]
public class ComponentGenerator : IIncrementalGenerator
{
    /// <inheritdoc />
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Null when the compilation cannot resolve the attribute; IsComponent then rejects everything.
        var attribute = context.CompilationProvider
            .Select((x, c) => x.GetTypeByMetadataName("Terminal.Shell.ComponentAttribute"));

        var types = context.CompilationProvider.SelectMany((x, c) =>
        {
            var visitor = new TypesVisitor(s =>
                // Must be declared in the current assembly
                s.ContainingAssembly.Equals(x.Assembly, SymbolEqualityComparer.Default) &&
                // And be a partial class (every declaration must be a class carrying the modifier)
                s.DeclaringSyntaxReferences.All(r =>
                    r.GetSyntax(c) is ClassDeclarationSyntax syntax &&
                    syntax.Modifiers.Any(t => t.IsKind(SyntaxKind.PartialKeyword))) &&
                // And be accessible within the current assembly (i.e. not a private nested type)
                x.IsSymbolAccessibleWithin(s, x.Assembly), c);

            // Walk only the source assembly: the compilation's merged global namespace would also
            // enumerate every referenced assembly, whose types the filter above rejects anyway.
            x.Assembly.GlobalNamespace.Accept(visitor);

            return visitor.TypeSymbols;
        });

        // Compute the matching attributes once per type and keep only types that have any.
        var components = types
            .Combine(attribute)
            .Select((x, _) => (Type: x.Left, Attributes: x.Left.GetAttributes().Where(a => IsComponent(x.Right, a)).ToList()))
            .Where(x => x.Attributes.Count > 0);

        context.RegisterSourceOutput(components, (ctx, data) => AddPartial(ctx, data.Type, data.Attributes));
    }

    /// <summary>
    /// Emits a <c>partial</c> declaration of <paramref name="type"/> decorated with one
    /// <c>[ExportMetadata]</c> attribute per distinct argument name found in <paramref name="attributes"/>.
    /// </summary>
    /// <param name="ctx">Context used to add the generated source file.</param>
    /// <param name="type">
    /// The component class. The visitor only descends into types that pass the partial-class
    /// filter, so every containing type is a partial class as well.
    /// </param>
    /// <param name="attributes">The component attributes applied to <paramref name="type"/>.</param>
    void AddPartial(SourceProductionContext ctx, INamedTypeSymbol type, List<AttributeData> attributes)
    {
        var builder = new StringBuilder();

        builder.AppendLine(
            """
            // <auto-generated />
            using System.Composition;

            """);

        // Types in the global namespace get no namespace declaration (its name is empty, which would
        // produce `namespace ;`). Otherwise emit the full dotted name: Name alone is only the last segment.
        if (type.ContainingNamespace is { IsGlobalNamespace: false } ns)
            builder.AppendLine($"namespace {ns.ToDisplayString()};").AppendLine();

        // Nested components must be re-declared inside their containing types, outermost first;
        // otherwise the generated part would declare an unrelated top-level class.
        // Stack enumeration is LIFO, so pushing innermost-first yields outermost-first.
        var containers = new Stack<INamedTypeSymbol>();
        for (var outer = type.ContainingType; outer != null; outer = outer.ContainingType)
            containers.Push(outer);

        var depth = 0;
        foreach (var container in containers)
        {
            builder.Append(' ', depth * 4).AppendLine($"partial class {GetDeclarationName(container)}");
            builder.Append(' ', depth * 4).AppendLine("{");
            depth++;
        }

        var names = new HashSet<string>();
        var arguments = attributes
            .Where(x => x.AttributeConstructor != null)
            // Zip pairs parameters with arguments without indexing, so a count mismatch in
            // erroneous user code cannot throw.
            .SelectMany(x => x.AttributeConstructor!.Parameters.Zip(x.ConstructorArguments, (p, v) => (p.Name, Value: v)))
            .Concat(attributes
                .SelectMany(x => x.NamedArguments.Select(a => (Name: a.Key, a.Value))));

        foreach (var (argName, value) in arguments)
        {
            if (argName.Length == 0)
                continue;

            // Constructor parameters are typically camelCase while properties are PascalCase.
            // Normalize so that e.g. ctor `id` and property `Id` map to the same metadata key "Id".
            var name = char.IsLower(argName[0])
                ? char.ToUpperInvariant(argName[0]) + argName[1..]
                : argName;

            // First occurrence wins: constructor arguments before named arguments, earlier attributes first.
            if (!names.Add(name))
                continue;

            builder.Append(' ', depth * 4).AppendLine($"[ExportMetadata(\"{name}\", {value.ToCSharpString()})]");
        }

        builder.Append(' ', depth * 4).AppendLine($"partial class {GetDeclarationName(type)} {{ }}");

        // Close the containing type declarations, innermost first.
        while (depth-- > 0)
            builder.Append(' ', depth * 4).AppendLine("}");

        ctx.AddSource(GetHintName(type) + ".g", builder.ToString().Replace("\r\n", "\n").Replace("\n", Environment.NewLine));
    }

    /// <summary>
    /// Returns the name used to re-declare <paramref name="type"/> in a partial declaration,
    /// including its own type parameters (e.g. <c>Foo&lt;T&gt;</c>) so the part binds to the same type.
    /// </summary>
    static string GetDeclarationName(INamedTypeSymbol type) =>
        type.TypeParameters.IsDefaultOrEmpty
            ? type.Name
            : $"{type.Name}<{string.Join(", ", type.TypeParameters.Select(p => p.Name))}>";

    /// <summary>
    /// Builds an <c>AddSource</c> hint name that is unique per type. The name includes the namespace,
    /// containing types and type parameters, so same-named types in different scopes do not collide.
    /// </summary>
    static string GetHintName(INamedTypeSymbol type)
    {
        var display = type.ToDisplayString();
        // Replace characters such as '<', '>', ',' and ' ' that are not safe in a file name.
        return new string(display.Select(ch => (char.IsLetterOrDigit(ch) || ch is '.' or '_') ? ch : '_').ToArray());
    }

    /// <summary>
    /// Determines whether <paramref name="attribute"/> is a component attribute, i.e. whether its
    /// class derives (directly or indirectly) from <paramref name="baseAttribute"/>.
    /// </summary>
    /// <remarks>
    /// Only base types are inspected, so an attribute whose class is exactly
    /// <paramref name="baseAttribute"/> does not match. NOTE(review): presumably the base attribute
    /// is not meant to be applied directly; confirm.
    /// </remarks>
    static bool IsComponent(INamedTypeSymbol? baseAttribute, AttributeData attribute)
    {
        if (baseAttribute == null)
            return false;

        // Advance to the next base type on every iteration so the walk terminates at System.Object.
        for (var baseType = attribute.AttributeClass?.BaseType; baseType != null; baseType = baseType.BaseType)
        {
            if (baseType.Equals(baseAttribute, SymbolEqualityComparer.Default))
                return true;
        }

        return false;
    }

    /// <summary>
    /// Symbol visitor that collects named types satisfying a predicate. Nested types are visited
    /// only when their containing type itself satisfies the predicate.
    /// </summary>
    class TypesVisitor : SymbolVisitor
    {
        readonly Func<ISymbol, bool> shouldInclude;
        readonly CancellationToken cancellation;
        readonly HashSet<INamedTypeSymbol> types = new(SymbolEqualityComparer.Default);

        /// <param name="shouldInclude">Predicate deciding whether a named type is collected.</param>
        /// <param name="cancellation">Token checked throughout the traversal.</param>
        public TypesVisitor(Func<ISymbol, bool> shouldInclude, CancellationToken cancellation)
        {
            this.shouldInclude = shouldInclude;
            this.cancellation = cancellation;
        }

        /// <summary>The types collected so far, deduplicated by symbol equality.</summary>
        public HashSet<INamedTypeSymbol> TypeSymbols => types;

        public override void VisitAssembly(IAssemblySymbol symbol)
        {
            cancellation.ThrowIfCancellationRequested();
            symbol.GlobalNamespace.Accept(this);
        }

        public override void VisitNamespace(INamespaceSymbol symbol)
        {
            foreach (var namespaceOrType in symbol.GetMembers())
            {
                cancellation.ThrowIfCancellationRequested();
                namespaceOrType.Accept(this);
            }
        }

        public override void VisitNamedType(INamedTypeSymbol type)
        {
            cancellation.ThrowIfCancellationRequested();

            // Skip rejected types (and their nested types) as well as types already collected.
            if (!shouldInclude(type) || !types.Add(type))
                return;

            var nestedTypes = type.GetTypeMembers();
            if (nestedTypes.IsDefaultOrEmpty)
                return;

            foreach (var nestedType in nestedTypes)
            {
                cancellation.ThrowIfCancellationRequested();
                nestedType.Accept(this);
            }
        }
    }
}