diff --git a/copyright_test.go b/copyright_test.go index 287e5948..bda479fd 100644 --- a/copyright_test.go +++ b/copyright_test.go @@ -17,8 +17,7 @@ import ( func TestCopyrightHeaders(t *testing.T) { var re = regexp.MustCompile(`Copyright \d{4} The Go MCP SDK Authors. All rights reserved. -Use of this source code is governed by an MIT-style -license that can be found in the LICENSE file.`) +Use of this source code is governed by (the license\n|an MIT-style\nlicense )that can be found in the LICENSE file.`) err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { if err != nil { diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 791e698d..6e58a6c9 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -5,9 +5,12 @@ package jsonrpc2 import ( + "bytes" "encoding/json" "errors" "fmt" + + "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) // ID is a Request identifier, which is defined by the spec to be a string, integer, or null. @@ -145,9 +148,9 @@ func toWireError(err error) *WireError { func EncodeMessage(msg Message) ([]byte, error) { wire := wireCombined{VersionTag: wireVersion} msg.marshal(&wire) - data, err := json.Marshal(&wire) + data, err := jsonMarshal(&wire) if err != nil { - return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + return nil, fmt.Errorf("marshaling jsonrpc message: %w", err) } return data, nil } @@ -158,11 +161,14 @@ func EncodeMessage(msg Message) ([]byte, error) { func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { wire := wireCombined{VersionTag: wireVersion} msg.marshal(&wire) - data, err := json.MarshalIndent(&wire, prefix, indent) - if err != nil { - return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + enc.SetIndent(prefix, indent) + if err := enc.Encode(&wire); err != nil { + return nil, fmt.Errorf("marshaling jsonrpc message: %w", err) } - return data, nil + return bytes.TrimRight(buf.Bytes(), "\n"), nil } func DecodeMessage(data []byte) (Message, error) { @@ -204,9 +210,31 @@ func marshalToRaw(obj any) (json.RawMessage, error) { if obj == nil { return nil, nil } - data, err := json.Marshal(obj) + data, err := jsonMarshal(obj) if err != nil { return nil, err } return json.RawMessage(data), nil } + +// jsonescaping is a compatibility parameter that allows to restore +// JSON escaping in the JSON marshaling, which stopped being the default +// in the 1.4.0 version of the SDK. See the documentation for the +// mcpgodebug package for instructions how to enable it. +// The option will be removed in the 1.6.0 version of the SDK. +var jsonescaping = mcpgodebug.Value("jsonescaping") + +// jsonMarshal marshals obj to JSON like json.Marshal but without HTML escaping. +func jsonMarshal(obj any) ([]byte, error) { + if jsonescaping == "1" { + return json.Marshal(obj) + } + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(obj); err != nil { + return nil, err + } + // json.Encoder.Encode adds a trailing newline. Trim it to be consistent with json.Marshal. + return bytes.TrimRight(buf.Bytes(), "\n"), nil +} diff --git a/internal/mcpgodebug/mcpgodebug.go b/internal/mcpgodebug/mcpgodebug.go new file mode 100644 index 00000000..7f8f7ca3 --- /dev/null +++ b/internal/mcpgodebug/mcpgodebug.go @@ -0,0 +1,52 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// Package mcpgodebug provides a mechanism to configure compatibility parameters +// via the MCPGODEBUG environment variable. +// +// The value of MCPGODEBUG is a comma-separated list of key=value pairs. +// For example: +// +// MCPGODEBUG=someoption=1,otheroption=value +package mcpgodebug + +import ( + "fmt" + "os" + "strings" +) + +const compatibilityEnvKey = "MCPGODEBUG" + +var compatibilityParams map[string]string + +func init() { + var err error + compatibilityParams, err = parseCompatibility(os.Getenv(compatibilityEnvKey)) + if err != nil { + panic(err) + } +} + +// Value returns the value of the compatibility parameter with the given key. +// It returns an empty string if the key is not set. +func Value(key string) string { + return compatibilityParams[key] +} + +func parseCompatibility(envValue string) (map[string]string, error) { + if envValue == "" { + return nil, nil + } + + params := make(map[string]string) + for part := range strings.SplitSeq(envValue, ",") { + k, v, ok := strings.Cut(part, "=") + if !ok { + return nil, fmt.Errorf("MCPGODEBUG: invalid format: %q", part) + } + params[strings.TrimSpace(k)] = strings.TrimSpace(v) + } + return params, nil +} diff --git a/internal/mcpgodebug/mcpgodebug_test.go b/internal/mcpgodebug/mcpgodebug_test.go new file mode 100644 index 00000000..9bbbc320 --- /dev/null +++ b/internal/mcpgodebug/mcpgodebug_test.go @@ -0,0 +1,88 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcpgodebug + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParseCompatibility_Success(t *testing.T) { + tests := []struct { + name string + envVal string + want map[string]string + }{ + { + name: "Basic", + envVal: "foo=bar,baz=qux", + want: map[string]string{ + "foo": "bar", + "baz": "qux", + }, + }, + { + name: "Empty", + envVal: "", + want: nil, + }, + { + name: "WithWhitespace", + envVal: " foo = bar \t, baz = qux ", + want: map[string]string{ + "foo": "bar", + "baz": "qux", + }, + }, + { + name: "WithEqualsSignInValue", + envVal: "foo=bar=baz", + want: map[string]string{ + "foo": "bar=baz", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseCompatibility(tt.envVal) + if err != nil { + t.Fatalf("parseCompatibility() failed: %v", err) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("parseCompatibility() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestParseCompatibility_Failure(t *testing.T) { + tests := []struct { + name string + envVal string + }{ + { + name: "NoEqualsSign", + envVal: "invalidformat", + }, + { + name: "MixedValidAndInvalid", + envVal: "foo=bar,baz", + }, + { + name: "EmptyPart", + envVal: "foo=bar,,baz=qux", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseCompatibility(tt.envVal) + if err == nil { + t.Error("parseCompatibility() expected error, got nil") + } + }) + } +}