diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 92003e6d7..30986b875 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -136,18 +136,18 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() - clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r) + creds, err := oauthserver.ExtractClientCredentials(r) if err != nil { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: %s", err.Error()) } // If no client credentials provided, continue without client authentication - if clientID == "" { + if creds.ClientID == "" { return ctx, nil } // Parse client_id as UUID - clientUUID, err := uuid.FromString(clientID) + clientUUID, err := uuid.FromString(creds.ClientID) if err != nil { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format") } @@ -162,8 +162,13 @@ func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (co return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err) } - // Validate authentication using centralized logic - if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil { + // Validate that the auth method used matches the client's registered method + if err := oauthserver.ValidateClientAuthMethod(client, creds.AuthMethod); err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", err.Error()) + } + + // Validate authentication using centralized logic (secret verification) + if err := oauthserver.ValidateClientAuthentication(client, creds.ClientSecret); err != nil { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", err.Error()) } diff --git a/internal/api/oauthserver/auth.go b/internal/api/oauthserver/auth.go index b7fbb8597..1c53df71d 100644 --- a/internal/api/oauthserver/auth.go +++ b/internal/api/oauthserver/auth.go @@ -8,27 +8,41 @@ import ( "io" "net/http" "strings" + + "github.com/supabase/auth/internal/models" ) +// ClientCredentials represents the extracted client credentials and authentication method used +type ClientCredentials struct { + ClientID string + ClientSecret string + AuthMethod string +} + // ExtractClientCredentials extracts OAuth client credentials from the request // Supports Basic auth header, form body parameters, and JSON body parameters -func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, err error) { +func ExtractClientCredentials(r *http.Request) (*ClientCredentials, error) { + creds := &ClientCredentials{} + // First, try Basic auth header: Authorization: Basic base64(client_id:client_secret) authHeader := r.Header.Get("Authorization") if authHeader != "" && strings.HasPrefix(authHeader, "Basic ") { encoded := strings.TrimPrefix(authHeader, "Basic ") decoded, err := base64.StdEncoding.DecodeString(encoded) if err != nil { - return "", "", errors.New("invalid basic auth encoding") + return nil, errors.New("invalid basic auth encoding") } credentials := string(decoded) parts := strings.SplitN(credentials, ":", 2) if len(parts) != 2 { - return "", "", errors.New("invalid basic auth format") + return nil, errors.New("invalid basic auth format") } - return parts[0], parts[1], nil + creds.ClientID = parts[0] + creds.ClientSecret = parts[1] + creds.AuthMethod = models.TokenEndpointAuthMethodClientSecretBasic + return creds, nil } // Check Content-Type to determine how to parse body parameters @@ -37,7 +51,7 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e // Parse JSON body body, err := io.ReadAll(r.Body) if err != nil { - return "", "", errors.New("failed to read request body") + return nil, errors.New("failed to read request body") } // Restore the body so other handlers can read it r.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -47,25 +61,32 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e ClientSecret string `json:"client_secret"` } if err := json.Unmarshal(body, &jsonData); err != nil { - return "", "", errors.New("failed to parse JSON body") + return nil, errors.New("failed to parse JSON body") } - clientID = jsonData.ClientID - clientSecret = jsonData.ClientSecret + creds.ClientID = jsonData.ClientID + creds.ClientSecret = jsonData.ClientSecret } else { // Fall back to form parameters if err := r.ParseForm(); err != nil { - return "", "", errors.New("failed to parse form") + return nil, errors.New("failed to parse form") } - clientID = r.FormValue("client_id") - clientSecret = r.FormValue("client_secret") + creds.ClientID = r.FormValue("client_id") + creds.ClientSecret = r.FormValue("client_secret") } // return error if client_id is not provided - if clientID == "" { - return "", "", errors.New("client_id is required") + if creds.ClientID == "" { + return nil, errors.New("client_id is required") + } + + // Determine auth method based on presence of client_secret in body + if creds.ClientSecret != "" { + creds.AuthMethod = models.TokenEndpointAuthMethodClientSecretPost + } else { + creds.AuthMethod = models.TokenEndpointAuthMethodNone } - return clientID, clientSecret, nil + return creds, nil } diff --git a/internal/api/oauthserver/client_auth.go b/internal/api/oauthserver/client_auth.go index 6ed505e84..a911d983a 100644 --- a/internal/api/oauthserver/client_auth.go +++ b/internal/api/oauthserver/client_auth.go @@ -108,3 +108,15 @@ func GetAllValidAuthMethods() []string { models.TokenEndpointAuthMethodClientSecretPost, } } + +// ValidateClientAuthMethod validates the authentication method used matches the registered method +func ValidateClientAuthMethod(client *models.OAuthServerClient, usedMethod string) error { + registeredMethod := client.GetTokenEndpointAuthMethod() + + if usedMethod != registeredMethod { + return fmt.Errorf("invalid authentication method: client is registered for '%s' but '%s' was used", + registeredMethod, usedMethod) + } + + return nil +} diff --git a/internal/api/oauthserver/client_auth_test.go b/internal/api/oauthserver/client_auth_test.go index 40fec9c40..6ab8af91a 100644 --- a/internal/api/oauthserver/client_auth_test.go +++ b/internal/api/oauthserver/client_auth_test.go @@ -395,6 +395,100 @@ func TestGetAllValidAuthMethods(t *testing.T) { } } +func TestValidateClientAuthMethod(t *testing.T) { + tests := []struct { + name string + client *models.OAuthServerClient + usedMethod string + expectError bool + errorContains string + }{ + { + name: "client registered for basic should accept basic", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypeConfidential, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic, + }, + usedMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expectError: false, + }, + { + name: "client registered for post should accept post", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypeConfidential, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretPost, + }, + usedMethod: models.TokenEndpointAuthMethodClientSecretPost, + expectError: false, + }, + { + name: "client registered for basic should reject post", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypeConfidential, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic, + }, + usedMethod: models.TokenEndpointAuthMethodClientSecretPost, + expectError: true, + errorContains: "invalid authentication method", + }, + { + name: "client registered for post should reject basic", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypeConfidential, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretPost, + }, + usedMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expectError: true, + errorContains: "invalid authentication method", + }, + { + name: "public client registered for none should accept none", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypePublic, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodNone, + }, + usedMethod: models.TokenEndpointAuthMethodNone, + expectError: false, + }, + { + name: "public client registered for none should reject basic", + client: &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientType: models.OAuthServerClientTypePublic, + TokenEndpointAuthMethod: models.TokenEndpointAuthMethodNone, + }, + usedMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expectError: true, + errorContains: "invalid authentication method", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateClientAuthMethod(tt.client, tt.usedMethod) + + if tt.expectError { + if err == nil { + t.Errorf("ValidateClientAuthMethod() expected error but got nil") + return + } + if tt.errorContains != "" && !containsString(err.Error(), tt.errorContains) { + t.Errorf("ValidateClientAuthMethod() error = %v, expected to contain %v", err, tt.errorContains) + } + } else { + if err != nil { + t.Errorf("ValidateClientAuthMethod() expected no error but got: %v", err) + } + } + }) + } +} + // Helper function to check if a string contains a substring func containsString(s, substr string) bool { return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) && diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index c61de4085..ec844c85c 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -51,24 +51,13 @@ type OAuthServerClientListResponse struct { // oauthServerClientToResponse converts a model to response format func oauthServerClientToResponse(client *models.OAuthServerClient) *OAuthServerClientResponse { - // Set token endpoint auth methods based on client type - var tokenEndpointAuthMethods string - // TODO(cemal) :: Remove this once we have the token endpoint auth method stored in the database - if client.IsPublic() { - // Public clients don't use client authentication - tokenEndpointAuthMethods = models.TokenEndpointAuthMethodNone - } else { - // Confidential clients use client secret authentication - tokenEndpointAuthMethods = models.TokenEndpointAuthMethodClientSecretBasic - } - response := &OAuthServerClientResponse{ ClientID: client.ID.String(), ClientType: client.ClientType, // OAuth 2.1 DCR fields RedirectURIs: client.GetRedirectURIs(), - TokenEndpointAuthMethod: tokenEndpointAuthMethods, + TokenEndpointAuthMethod: client.GetTokenEndpointAuthMethod(), GrantTypes: client.GetGrantTypes(), ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 ClientName: utilities.StringValue(client.ClientName), diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index dff27c774..7df96c547 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -263,15 +263,29 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer // Determine client type using centralized logic clientType := DetermineClientType(params.ClientType, params.TokenEndpointAuthMethod) + // Determine token_endpoint_auth_method + // If explicitly provided, use it; otherwise set default based on client type + // Per RFC 7591: "If unspecified or omitted, the default is 'client_secret_basic'" + // For public clients, the default is 'none' since they don't have a client secret + tokenEndpointAuthMethod := params.TokenEndpointAuthMethod + if tokenEndpointAuthMethod == "" { + if clientType == models.OAuthServerClientTypePublic { + tokenEndpointAuthMethod = models.TokenEndpointAuthMethodNone + } else { + tokenEndpointAuthMethod = models.TokenEndpointAuthMethodClientSecretBasic + } + } + db := s.db.WithContext(ctx) client := &models.OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - RegistrationType: params.RegistrationType, - ClientType: clientType, - ClientName: utilities.StringPtr(params.ClientName), - ClientURI: utilities.StringPtr(params.ClientURI), - LogoURI: utilities.StringPtr(params.LogoURI), + ID: uuid.Must(uuid.NewV4()), + RegistrationType: params.RegistrationType, + ClientType: clientType, + TokenEndpointAuthMethod: tokenEndpointAuthMethod, + ClientName: utilities.StringPtr(params.ClientName), + ClientURI: utilities.StringPtr(params.ClientURI), + LogoURI: utilities.StringPtr(params.LogoURI), } client.SetRedirectURIs(params.RedirectURIs) diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index c53a7ee62..3dc5dc893 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "net/url" + "slices" "strings" "time" @@ -28,10 +29,11 @@ const ( // OAuthServerClient represents an OAuth client application registered with this OAuth server type OAuthServerClient struct { - ID uuid.UUID `json:"client_id" db:"id"` - ClientSecretHash string `json:"-" db:"client_secret_hash"` - RegistrationType string `json:"registration_type" db:"registration_type"` - ClientType string `json:"client_type" db:"client_type"` + ID uuid.UUID `json:"client_id" db:"id"` + ClientSecretHash string `json:"-" db:"client_secret_hash"` + RegistrationType string `json:"registration_type" db:"registration_type"` + ClientType string `json:"client_type" db:"client_type"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" db:"token_endpoint_auth_method"` RedirectURIs string `json:"-" db:"redirect_uris"` GrantTypes string `json:"grant_types" db:"grant_types"` @@ -82,6 +84,34 @@ func (c *OAuthServerClient) Validate() error { return fmt.Errorf("client_secret is not allowed for public clients, use PKCE instead") } + // Apply default token_endpoint_auth_method per RFC 7591: + // "If unspecified or omitted, the default is 'client_secret_basic'" + // For public clients, the default is 'none' since they don't have a client secret + if c.TokenEndpointAuthMethod == "" { + if c.ClientType == OAuthServerClientTypePublic { + c.TokenEndpointAuthMethod = TokenEndpointAuthMethodNone + } else { + c.TokenEndpointAuthMethod = TokenEndpointAuthMethodClientSecretBasic + } + } + + // Validate token_endpoint_auth_method + validMethods := []string{TokenEndpointAuthMethodNone, TokenEndpointAuthMethodClientSecretBasic, TokenEndpointAuthMethodClientSecretPost} + if !slices.Contains(validMethods, c.TokenEndpointAuthMethod) { + return fmt.Errorf("token_endpoint_auth_method must be one of: %s, %s, %s", + TokenEndpointAuthMethodNone, TokenEndpointAuthMethodClientSecretBasic, TokenEndpointAuthMethodClientSecretPost) + } + + // Public clients must use 'none' + if c.ClientType == OAuthServerClientTypePublic && c.TokenEndpointAuthMethod != TokenEndpointAuthMethodNone { + return fmt.Errorf("public clients must use token_endpoint_auth_method '%s'", TokenEndpointAuthMethodNone) + } + + // Confidential clients cannot use 'none' + if c.ClientType == OAuthServerClientTypeConfidential && c.TokenEndpointAuthMethod == TokenEndpointAuthMethodNone { + return fmt.Errorf("confidential clients cannot use token_endpoint_auth_method '%s'", TokenEndpointAuthMethodNone) + } + return nil } @@ -121,6 +151,11 @@ func (c *OAuthServerClient) IsConfidential() bool { return c.ClientType == OAuthServerClientTypeConfidential } +// GetTokenEndpointAuthMethod returns the token endpoint auth method +func (c *OAuthServerClient) GetTokenEndpointAuthMethod() string { + return c.TokenEndpointAuthMethod +} + // IsGrantTypeAllowed returns true if the client is allowed to use the specified grant type func (c *OAuthServerClient) IsGrantTypeAllowed(grantType string) bool { allowedTypes := c.GetGrantTypes() diff --git a/migrations/20251216000000_add_token_endpoint_auth_method.up.sql b/migrations/20251216000000_add_token_endpoint_auth_method.up.sql new file mode 100644 index 000000000..3802fa576 --- /dev/null +++ b/migrations/20251216000000_add_token_endpoint_auth_method.up.sql @@ -0,0 +1,18 @@ +-- Add token_endpoint_auth_method column to oauth_clients table +-- Per RFC 7591: "If unspecified or omitted, the default is 'client_secret_basic'" +-- For public clients, the default is 'none' since they don't have a client secret +/* auth_migration: 20251216000000 */ +alter table {{ index .Options "Namespace" }}.oauth_clients + add column if not exists token_endpoint_auth_method text check (token_endpoint_auth_method in ('client_secret_basic', 'client_secret_post', 'none')); + +-- Set default values for existing clients based on their client_type +update {{ index .Options "Namespace" }}.oauth_clients + set token_endpoint_auth_method = case + when client_type = 'public' then 'none' + else 'client_secret_basic' + end + where token_endpoint_auth_method is null; + +-- Now make the column not null +alter table {{ index .Options "Namespace" }}.oauth_clients + alter column token_endpoint_auth_method set not null;