diff --git a/src/ModelContextProtocol.Analyzers/CS1066Suppressor.cs b/src/ModelContextProtocol.Analyzers/CS1066Suppressor.cs new file mode 100644 index 00000000..ff8cdcc3 --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/CS1066Suppressor.cs @@ -0,0 +1,148 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Threading; + +namespace ModelContextProtocol.Analyzers; + +/// +/// Suppresses CS1066 warnings for MCP server methods that have optional parameters. +/// +/// +/// +/// CS1066 is issued when a partial method's implementing declaration has default parameter values. +/// For partial methods, only the defining declaration's defaults are used by callers, +/// making the implementing declaration's defaults redundant. +/// +/// +/// However, for MCP tool, prompt, and resource methods, users often want to specify default values +/// in their implementing declaration for documentation purposes. The XmlToDescriptionGenerator +/// automatically copies these defaults to the generated defining declaration, making them functional. +/// +/// +/// This suppressor suppresses CS1066 for methods marked with [McpServerTool], [McpServerPrompt], +/// or [McpServerResource] attributes, allowing users to specify defaults in their code without warnings. +/// +/// +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class CS1066Suppressor : DiagnosticSuppressor +{ + private static readonly SuppressionDescriptor McpToolSuppression = new( + id: "MCP_CS1066_TOOL", + suppressedDiagnosticId: "CS1066", + justification: "Default values on MCP tool method implementing declarations are copied to the generated defining declaration by the source generator."); + + private static readonly SuppressionDescriptor McpPromptSuppression = new( + id: "MCP_CS1066_PROMPT", + suppressedDiagnosticId: "CS1066", + justification: "Default values on MCP prompt method implementing declarations are copied to the generated defining declaration by the source generator."); + + private static readonly SuppressionDescriptor McpResourceSuppression = new( + id: "MCP_CS1066_RESOURCE", + suppressedDiagnosticId: "CS1066", + justification: "Default values on MCP resource method implementing declarations are copied to the generated defining declaration by the source generator."); + + /// + public override ImmutableArray SupportedSuppressions => + ImmutableArray.Create(McpToolSuppression, McpPromptSuppression, McpResourceSuppression); + + /// + public override void ReportSuppressions(SuppressionAnalysisContext context) + { + // Cache semantic models and attribute symbols per syntax tree/compilation to avoid redundant calls + Dictionary? semanticModelCache = null; + INamedTypeSymbol? mcpToolAttribute = null; + INamedTypeSymbol? mcpPromptAttribute = null; + INamedTypeSymbol? mcpResourceAttribute = null; + bool attributesResolved = false; + + foreach (Diagnostic diagnostic in context.ReportedDiagnostics) + { + Location? location = diagnostic.Location; + SyntaxTree? tree = location.SourceTree; + if (tree is null) + { + continue; + } + + SyntaxNode root = tree.GetRoot(context.CancellationToken); + SyntaxNode? node = root.FindNode(location.SourceSpan); + + // Find the containing method declaration + MethodDeclarationSyntax? method = node.FirstAncestorOrSelf(); + if (method is null) + { + continue; + } + + // Get or cache the semantic model for this tree + semanticModelCache ??= new Dictionary(); + if (!semanticModelCache.TryGetValue(tree, out SemanticModel? semanticModel)) + { + semanticModel = context.GetSemanticModel(tree); + semanticModelCache[tree] = semanticModel; + } + + // Resolve attribute symbols once per compilation + if (!attributesResolved) + { + mcpToolAttribute = semanticModel.Compilation.GetTypeByMetadataName(McpAttributeNames.McpServerToolAttribute); + mcpPromptAttribute = semanticModel.Compilation.GetTypeByMetadataName(McpAttributeNames.McpServerPromptAttribute); + mcpResourceAttribute = semanticModel.Compilation.GetTypeByMetadataName(McpAttributeNames.McpServerResourceAttribute); + attributesResolved = true; + } + + // Check for MCP attributes + SuppressionDescriptor? suppression = GetSuppressionForMethod(method, semanticModel, mcpToolAttribute, mcpPromptAttribute, mcpResourceAttribute, context.CancellationToken); + if (suppression is not null) + { + context.ReportSuppression(Suppression.Create(suppression, diagnostic)); + } + } + } + + private static SuppressionDescriptor? GetSuppressionForMethod( + MethodDeclarationSyntax method, + SemanticModel semanticModel, + INamedTypeSymbol? mcpToolAttribute, + INamedTypeSymbol? mcpPromptAttribute, + INamedTypeSymbol? mcpResourceAttribute, + CancellationToken cancellationToken) + { + IMethodSymbol? methodSymbol = semanticModel.GetDeclaredSymbol(method, cancellationToken); + + if (methodSymbol is null) + { + return null; + } + + foreach (AttributeData attribute in methodSymbol.GetAttributes()) + { + INamedTypeSymbol? attributeClass = attribute.AttributeClass; + if (attributeClass is null) + { + continue; + } + + if (mcpToolAttribute is not null && SymbolEqualityComparer.Default.Equals(attributeClass, mcpToolAttribute)) + { + return McpToolSuppression; + } + + if (mcpPromptAttribute is not null && SymbolEqualityComparer.Default.Equals(attributeClass, mcpPromptAttribute)) + { + return McpPromptSuppression; + } + + if (mcpResourceAttribute is not null && SymbolEqualityComparer.Default.Equals(attributeClass, mcpResourceAttribute)) + { + return McpResourceSuppression; + } + } + + return null; + } +} diff --git a/src/ModelContextProtocol.Analyzers/McpAttributeNames.cs b/src/ModelContextProtocol.Analyzers/McpAttributeNames.cs new file mode 100644 index 00000000..f615d07d --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/McpAttributeNames.cs @@ -0,0 +1,12 @@ +namespace ModelContextProtocol.Analyzers; + +/// +/// Contains the fully qualified metadata names for MCP server attributes. +/// +internal static class McpAttributeNames +{ + public const string McpServerToolAttribute = "ModelContextProtocol.Server.McpServerToolAttribute"; + public const string McpServerPromptAttribute = "ModelContextProtocol.Server.McpServerPromptAttribute"; + public const string McpServerResourceAttribute = "ModelContextProtocol.Server.McpServerResourceAttribute"; + public const string DescriptionAttribute = "System.ComponentModel.DescriptionAttribute"; +} diff --git a/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs index 992a8394..71982a3f 100644 --- a/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs +++ b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs @@ -18,18 +18,14 @@ namespace ModelContextProtocol.Analyzers; public sealed class XmlToDescriptionGenerator : IIncrementalGenerator { private const string GeneratedFileName = "ModelContextProtocol.Descriptions.g.cs"; - private const string McpServerToolAttributeName = "ModelContextProtocol.Server.McpServerToolAttribute"; - private const string McpServerPromptAttributeName = "ModelContextProtocol.Server.McpServerPromptAttribute"; - private const string McpServerResourceAttributeName = "ModelContextProtocol.Server.McpServerResourceAttribute"; - private const string DescriptionAttributeName = "System.ComponentModel.DescriptionAttribute"; public void Initialize(IncrementalGeneratorInitializationContext context) { // Extract method information for all MCP tools, prompts, and resources. // The transform extracts all necessary data upfront so the output doesn't depend on the compilation. - var allMethods = CreateProviderForAttribute(context, McpServerToolAttributeName).Collect() - .Combine(CreateProviderForAttribute(context, McpServerPromptAttributeName).Collect()) - .Combine(CreateProviderForAttribute(context, McpServerResourceAttributeName).Collect()) + var allMethods = CreateProviderForAttribute(context, McpAttributeNames.McpServerToolAttribute).Collect() + .Combine(CreateProviderForAttribute(context, McpAttributeNames.McpServerPromptAttribute).Collect()) + .Combine(CreateProviderForAttribute(context, McpAttributeNames.McpServerResourceAttribute).Collect()) .Select(static (tuple, _) => { var ((tools, prompts), resources) = tuple; @@ -84,7 +80,7 @@ private static MethodToGenerate ExtractMethodInfo( Compilation compilation) { bool isPartial = methodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); - var descriptionAttribute = compilation.GetTypeByMetadataName(DescriptionAttributeName); + var descriptionAttribute = compilation.GetTypeByMetadataName(McpAttributeNames.DescriptionAttribute); // Try to extract XML documentation var (xmlDocs, hasInvalidXml) = TryExtractXmlDocumentation(methodSymbol); diff --git a/tests/ModelContextProtocol.Analyzers.Tests/CS1066SuppressorTests.cs b/tests/ModelContextProtocol.Analyzers.Tests/CS1066SuppressorTests.cs new file mode 100644 index 00000000..57eaaf41 --- /dev/null +++ b/tests/ModelContextProtocol.Analyzers.Tests/CS1066SuppressorTests.cs @@ -0,0 +1,232 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using System.Collections.Immutable; +using Xunit; + +namespace ModelContextProtocol.Analyzers.Tests; + +public class CS1066SuppressorTests +{ + [Fact] + public void Suppressor_WithMcpServerToolAttribute_SuppressesCS1066() + { + var result = RunSuppressor(""" + using ModelContextProtocol.Server; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + [McpServerTool] + public partial string TestMethod(string input = "default"); + } + + public partial class TestTools + { + public partial string TestMethod(string input = "default") + { + return input; + } + } + """); + + // Check we have the CS1066 diagnostics from compiler + var cs1066FromCompiler = result.CompilerDiagnostics.Where(d => d.Id == "CS1066").ToList(); + + // CS1066 should be suppressed in the final diagnostics + var unsuppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && !d.IsSuppressed).ToList(); + var suppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && d.IsSuppressed).ToList(); + + Assert.True(cs1066FromCompiler.Count > 0 || suppressedCs1066.Count > 0, + $"Expected CS1066 diagnostics. Compiler diagnostics: {string.Join(", ", result.CompilerDiagnostics.Select(d => d.Id))}"); + Assert.Empty(unsuppressedCs1066); + } + + [Fact] + public void Suppressor_WithMcpServerPromptAttribute_SuppressesCS1066() + { + var result = RunSuppressor(""" + using ModelContextProtocol.Server; + + namespace Test; + + [McpServerPromptType] + public partial class TestPrompts + { + [McpServerPrompt] + public partial string TestPrompt(string input = "default"); + } + + public partial class TestPrompts + { + public partial string TestPrompt(string input = "default") + { + return input; + } + } + """); + + // Check we have the CS1066 diagnostics from compiler + var cs1066FromCompiler = result.CompilerDiagnostics.Where(d => d.Id == "CS1066").ToList(); + + // CS1066 should be suppressed in the final diagnostics + var unsuppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && !d.IsSuppressed).ToList(); + var suppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && d.IsSuppressed).ToList(); + + Assert.True(cs1066FromCompiler.Count > 0 || suppressedCs1066.Count > 0, + $"Expected CS1066 diagnostics. Compiler diagnostics: {string.Join(", ", result.CompilerDiagnostics.Select(d => d.Id))}"); + Assert.Empty(unsuppressedCs1066); + } + + [Fact] + public void Suppressor_WithMcpServerResourceAttribute_SuppressesCS1066() + { + var result = RunSuppressor(""" + using ModelContextProtocol.Server; + + namespace Test; + + [McpServerResourceType] + public partial class TestResources + { + [McpServerResource("test://resource")] + public partial string TestResource(string input = "default"); + } + + public partial class TestResources + { + public partial string TestResource(string input = "default") + { + return input; + } + } + """); + + // Check we have the CS1066 diagnostics from compiler + var cs1066FromCompiler = result.CompilerDiagnostics.Where(d => d.Id == "CS1066").ToList(); + + // CS1066 should be suppressed in the final diagnostics + var unsuppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && !d.IsSuppressed).ToList(); + var suppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && d.IsSuppressed).ToList(); + + Assert.True(cs1066FromCompiler.Count > 0 || suppressedCs1066.Count > 0, + $"Expected CS1066 diagnostics. Compiler diagnostics: {string.Join(", ", result.CompilerDiagnostics.Select(d => d.Id))}"); + Assert.Empty(unsuppressedCs1066); + } + + [Fact] + public void Suppressor_WithoutMcpAttribute_DoesNotSuppressCS1066() + { + var result = RunSuppressor(""" + namespace Test; + + public partial class TestTools + { + public partial string TestMethod(string input = "default"); + } + + public partial class TestTools + { + public partial string TestMethod(string input = "default") + { + return input; + } + } + """); + + // CS1066 should NOT be suppressed (no MCP attribute) + // Check we have the CS1066 diagnostic from compiler + var cs1066FromCompiler = result.CompilerDiagnostics.Where(d => d.Id == "CS1066").ToList(); + Assert.NotEmpty(cs1066FromCompiler); + + // It should NOT be suppressed in the final diagnostics (still present as unsuppressed) + var unsuppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && !d.IsSuppressed).ToList(); + Assert.NotEmpty(unsuppressedCs1066); + Assert.DoesNotContain(result.Diagnostics, d => d.Id == "CS1066" && d.IsSuppressed); + } + + [Fact] + public void Suppressor_WithMultipleParameters_SuppressesAllCS1066() + { + var result = RunSuppressor(""" + using ModelContextProtocol.Server; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + [McpServerTool] + public partial string TestMethod(string input = "default", int count = 42, bool flag = false); + } + + public partial class TestTools + { + public partial string TestMethod(string input = "default", int count = 42, bool flag = false) + { + return input; + } + } + """); + + // Check we have CS1066 diagnostics from compiler (one per parameter with default) + var cs1066FromCompiler = result.CompilerDiagnostics.Where(d => d.Id == "CS1066").ToList(); + Assert.Equal(3, cs1066FromCompiler.Count); // Three parameters with defaults + + // All CS1066 warnings should be suppressed + var unsuppressedCs1066 = result.Diagnostics.Where(d => d.Id == "CS1066" && !d.IsSuppressed).ToList(); + Assert.Empty(unsuppressedCs1066); + } + + private SuppressorResult RunSuppressor(string source) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + + // Get reference assemblies + List referenceList = + [ + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.ComponentModel.DescriptionAttribute).Assembly.Location), + ]; + + // Add all necessary runtime assemblies + var runtimePath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + referenceList.Add(MetadataReference.CreateFromFile(Path.Combine(runtimePath, "System.Runtime.dll"))); + referenceList.Add(MetadataReference.CreateFromFile(Path.Combine(runtimePath, "netstandard.dll"))); + + // Add ModelContextProtocol.Core if available + var coreAssemblyPath = Path.Combine(AppContext.BaseDirectory, "ModelContextProtocol.Core.dll"); + if (File.Exists(coreAssemblyPath)) + { + referenceList.Add(MetadataReference.CreateFromFile(coreAssemblyPath)); + } + + var compilation = CSharpCompilation.Create( + "TestAssembly", + [syntaxTree], + referenceList, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + // Get compilation diagnostics first (includes CS1066) + var compilerDiagnostics = compilation.GetDiagnostics(); + + // Run the suppressor + var analyzers = ImmutableArray.Create(new CS1066Suppressor()); + var compilationWithAnalyzers = compilation.WithAnalyzers(analyzers); + var allDiagnostics = compilationWithAnalyzers.GetAllDiagnosticsAsync().GetAwaiter().GetResult(); + + return new SuppressorResult + { + Diagnostics = allDiagnostics.ToList(), + CompilerDiagnostics = compilerDiagnostics.ToList() + }; + } + + private class SuppressorResult + { + public List Diagnostics { get; set; } = []; + public List CompilerDiagnostics { get; set; } = []; + } +} diff --git a/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs index 0feacd0b..b2cf8365 100644 --- a/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs +++ b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs @@ -1,5 +1,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using System.Collections.Immutable; using System.Diagnostics.CodeAnalysis; using Xunit; @@ -347,7 +349,7 @@ public static string TestMethod(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -374,7 +376,7 @@ public static string TestMethod(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -405,7 +407,7 @@ public static string TestMethod(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -434,7 +436,7 @@ public static string TestMethod(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -552,7 +554,7 @@ public static string TestMethod(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -583,7 +585,7 @@ public static string TestPrompt(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -614,7 +616,7 @@ public static string TestResource(string input) return input; } } - """); + """, "MCP002"); Assert.True(result.Success); Assert.Empty(result.GeneratedSources); @@ -695,7 +697,7 @@ public static partial string TestInvalidXml(string input) return input; } } - """); + """, "MCP001"); // Should not throw, generates partial implementation without Description attributes Assert.True(result.Success); @@ -1717,7 +1719,7 @@ partial class TestTools AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); } - private GeneratorRunResult RunGenerator([StringSyntax("C#-test")] string source) + private GeneratorRunResult RunGenerator([StringSyntax("C#-test")] string source, params string[] expectedDiagnosticIds) { var syntaxTree = CSharpSyntaxTree.ParseText(source); @@ -1755,15 +1757,34 @@ private GeneratorRunResult RunGenerator([StringSyntax("C#-test")] string source) var driver = (CSharpGeneratorDriver)CSharpGeneratorDriver .Create(new XmlToDescriptionGenerator()) - .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics); + .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var generatorDiagnostics); var runResult = driver.GetRunResult(); + // Run the suppressor to check that CS1066 warnings for MCP methods are suppressed + var analyzers = ImmutableArray.Create(new CS1066Suppressor()); + var compilationWithAnalyzers = outputCompilation.WithAnalyzers(analyzers); + var allDiagnostics = compilationWithAnalyzers.GetAllDiagnosticsAsync().GetAwaiter().GetResult(); + + // Check for any unsuppressed CS1066 warnings - these should be suppressed by our suppressor + var unsuppressedCs1066 = allDiagnostics + .Where(d => d.Id == "CS1066" && !d.IsSuppressed) + .ToList(); + + // Collect all diagnostics from the generator (any verbosity level) + var allGeneratorDiagnostics = generatorDiagnostics.Concat(unsuppressedCs1066).ToList(); + + // Check for unexpected diagnostics - any diagnostic that isn't in the expected list + var expectedSet = new HashSet(expectedDiagnosticIds); + var unexpectedDiagnostics = allGeneratorDiagnostics + .Where(d => !expectedSet.Contains(d.Id)) + .ToList(); + return new GeneratorRunResult { - Success = !diagnostics.Any(d => d.Severity == DiagnosticSeverity.Error), + Success = unexpectedDiagnostics.Count == 0, GeneratedSources = runResult.GeneratedTrees.Select(t => (t.FilePath, t.GetText())).ToList(), - Diagnostics = diagnostics.ToList(), + Diagnostics = allGeneratorDiagnostics, Compilation = outputCompilation }; }