// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package oidc

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"math/big"
	"os"
	"path"
	"path/filepath"
	"testing"
	"time"

	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/shoenig/test/must"
)

// TestBuildClientAssertionJWT_ClientSecret covers BuildClientAssertionJWT when
// the assertion uses the client-secret key source (HMAC algorithms). Each case
// either expects a token to be built or an error whose message contains
// expectedErr.
func TestBuildClientAssertionJWT_ClientSecret(t *testing.T) {
	const (
		longSecret  = "1234567890abcdefghijklmnopqrstuvwxyz" // 36 bytes; HS256 needs at least 32
		shortSecret = "test-client-secret"                   // 18 bytes; too short for HS256
	)

	tests := []struct {
		name        string
		config      *structs.ACLAuthMethodConfig
		wantErr     bool
		expectedErr string
	}{
		{
			name: "valid client secret",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: longSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "HS256",
					Audience:     []string{"test-audience"},
					ExtraHeaders: map[string]string{
						"test-header": "test-value",
					},
				},
			},
			wantErr: false,
		},
		{
			name: "nil config",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:        "test-client-id",
				OIDCClientSecret:    shortSecret,
				OIDCClientAssertion: nil,
			},
			wantErr:     true,
			expectedErr: `no auth method config or client assertion`,
		},
		{
			name: "invalid client secret length",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: shortSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "HS256",
					Audience:     []string{"test-audience"},
					ExtraHeaders: map[string]string{
						"test-header": "test-value",
					},
				},
			},
			wantErr:     true,
			expectedErr: `invalid secret length for algorithm: "HS256" must be at least 32 bytes long`,
		},
		{
			name: "invalid client secret kid in extra header",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: longSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "HS256",
					Audience:     []string{"test-audience"},
					ExtraHeaders: map[string]string{
						"kid": "test-kid",
					},
				},
			},
			wantErr:     true,
			expectedErr: `WithHeaders: "kid" not allowed in WithHeaders; use WithKeyID instead`,
		},
		{
			name: "invalid key algorithm none",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: longSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "none",
					Audience:     []string{"test-audience"},
					ExtraHeaders: map[string]string{
						"test-header": "test-value",
					},
				},
			},
			wantErr:     true,
			expectedErr: `unsupported algorithm "none" for client secret`,
		},
		{
			name: "invalid key algorithm None",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: longSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "None",
					Audience:     []string{"test-audience"},
					ExtraHeaders: map[string]string{
						"test-header": "test-value",
					},
				},
			},
			wantErr:     true,
			expectedErr: `unsupported algorithm "None" for client secret`,
		},
		// An assertion with no Audience must be rejected.
		{
			name: "invalid missing audience",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID:     "test-client-id",
				OIDCClientSecret: longSecret,
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceClientSecret,
					KeyAlgorithm: "HS256",
					ExtraHeaders: map[string]string{
						"test-header": "test-value",
					},
				},
			},
			wantErr:     true,
			expectedErr: "missing Audience",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE(review): Canonicalize is relied on to carry the config's
			// OIDCClientSecret over to the client assertion; confirm against
			// structs.ACLAuthMethodConfig.Canonicalize.
			tt.config.Canonicalize()

			// Client-secret signing is exercised without a Nomad signing key
			// or key ID.
			jwt, err := BuildClientAssertionJWT(tt.config, nil, "")
			if tt.wantErr {
				must.Error(t, err)
				must.StrContains(t, err.Error(), tt.expectedErr)
				return
			}
			must.NoError(t, err)
			must.NotNil(t, jwt)
		})
	}
}

// TestBuildClientAssertionJWT_PrivateKey covers BuildClientAssertionJWT when the
// assertion is signed with a user-supplied RSA private key. The key is given
// either inline (PemKey) or as a file (PemKeyFile), optionally with a
// certificate that is also inline (PemCert) or a file (PemCertFile).
func TestBuildClientAssertionJWT_PrivateKey(t *testing.T) {
	nomadKey := generateTestPrivateKey(t)
	nomadKeyPath := writeTestPrivateKeyToFile(t, nomadKey)
	nomadKID := "anything"
	nomadCert := generateTestCertificate(t, nomadKey)
	nomadCertPath := writeTestCertToFile(t, nomadCert)

	// A file that exists but contains neither a PEM key nor a PEM certificate.
	nonKeyCertFile := path.Join(t.TempDir(), "bad.key.cert.pem")
	must.NoError(t, os.WriteFile(nonKeyCertFile, []byte("not a key or cert"), 0644))

	keyPEM := encodeTestPrivateKey(nomadKey)

	// privateKeyConfig wraps key in an RS256, private-key-sourced assertion
	// that carries the client ID and audience shared by every non-nil case.
	privateKeyConfig := func(key *structs.OIDCClientAssertionKey) *structs.ACLAuthMethodConfig {
		return &structs.ACLAuthMethodConfig{
			OIDCClientID: "test-client-id",
			OIDCClientAssertion: &structs.OIDCClientAssertion{
				KeySource:    structs.OIDCKeySourcePrivateKey,
				Audience:     []string{"test-audience"},
				KeyAlgorithm: "RS256",
				PrivateKey:   key,
			},
		}
	}

	cases := []struct {
		name        string
		config      *structs.ACLAuthMethodConfig
		wantErr     bool
		expectedErr string
	}{
		{
			name:    "nil config",
			config:  &structs.ACLAuthMethodConfig{OIDCClientAssertion: nil},
			wantErr: true,
		},
		{
			name: "valid private key source with pem key",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKey: keyPEM,
				KeyID:  nomadKID,
			}),
		},
		{
			name: "valid private key source with pem key with pem cert",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKey:  keyPEM,
				PemCert: encodeTestCert(nomadCert),
			}),
		},
		{
			name: "valid private key source with pem cert file",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKey:      keyPEM,
				PemCertFile: nomadCertPath,
			}),
		},
		{
			name: "valid private key source with pem key file with Key ID",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKeyFile: nomadKeyPath,
				KeyID:      nomadKID,
			}),
		},
		// The key file path does not exist (it treats a file as a directory).
		{
			name: "invalid private key source with pem key file",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKeyFile: nomadKeyPath + "/invalid",
				KeyID:      nomadKID,
			}),
			wantErr:     true,
			expectedErr: "error reading PemKeyFile",
		},
		// The key file exists but does not hold an RSA key.
		{
			name: "invalid private key file contents",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKeyFile: nonKeyCertFile,
				KeyID:      nomadKID,
			}),
			wantErr:     true,
			expectedErr: "error parsing PemKeyFile: invalid key:",
		},
		// The certificate file exists but does not hold a PEM block.
		{
			name: "invalid certificate file contents",
			config: privateKeyConfig(&structs.OIDCClientAssertionKey{
				PemKey:      keyPEM,
				PemCertFile: nonKeyCertFile,
			}),
			wantErr:     true,
			expectedErr: "failed to decode PemCertFile PEM block",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Apply config defaults before building the assertion.
			tc.config.Canonicalize()
			token, err := BuildClientAssertionJWT(tc.config, nomadKey, nomadKID)
			if !tc.wantErr {
				must.NoError(t, err)
				must.NotNil(t, token)
				return
			}
			must.Error(t, err)
			must.StrContains(t, err.Error(), tc.expectedErr)
		})
	}
}

// TestBuildClientAssertionJWT_NomadKey covers BuildClientAssertionJWT when the
// assertion's key source is Nomad itself. In that case the key and key ID
// passed in by the caller are used for signing.
func TestBuildClientAssertionJWT_NomadKey(t *testing.T) {
	nomadKey := generateTestPrivateKey(t)
	nomadKID := "anything"

	cases := []struct {
		name    string
		config  *structs.ACLAuthMethodConfig
		wantErr bool
	}{
		{
			name:    "nil config",
			config:  &structs.ACLAuthMethodConfig{OIDCClientAssertion: nil},
			wantErr: true,
		},
		{
			name: "nomad key source",
			config: &structs.ACLAuthMethodConfig{
				OIDCClientID: "test-client-id",
				OIDCClientAssertion: &structs.OIDCClientAssertion{
					KeySource:    structs.OIDCKeySourceNomad,
					KeyAlgorithm: "RS256",
					Audience:     []string{"test-audience"},
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Apply config defaults before building the assertion.
			tc.config.Canonicalize()
			token, err := BuildClientAssertionJWT(tc.config, nomadKey, nomadKID)
			if !tc.wantErr {
				must.NoError(t, err)
				must.NotNil(t, token)
				return
			}
			must.Error(t, err)
		})
	}
}

// TestBuildClientAssertionJWT_PrivateKeyExpiredCert covers private-key
// assertions whose inline PEM material is unusable. One case uses a key whose
// modulus has been corrupted; the other uses a certificate whose validity
// window has already passed.
func TestBuildClientAssertionJWT_PrivateKeyExpiredCert(t *testing.T) {
	nomadKey := generateTestPrivateKey(t)
	nomadInvalidKey := generateInvalidTestPrivateKey(t)
	nomadKID := "anything"
	nomadCert := generateTestCertificate(t, nomadKey)
	nomadExpiredCert := generateExpiredTestCertificate(t, nomadKey)

	// pemConfig builds an RS256 private-key assertion from inline PEM key and
	// certificate strings.
	pemConfig := func(keyPEM, certPEM string) *structs.ACLAuthMethodConfig {
		return &structs.ACLAuthMethodConfig{
			OIDCClientID: "test-client-id",
			OIDCClientAssertion: &structs.OIDCClientAssertion{
				KeySource:    structs.OIDCKeySourcePrivateKey,
				Audience:     []string{"test-audience"},
				KeyAlgorithm: "RS256",
				PrivateKey: &structs.OIDCClientAssertionKey{
					PemKey:  keyPEM,
					PemCert: certPEM,
				},
			},
		}
	}

	cases := []struct {
		name        string
		config      *structs.ACLAuthMethodConfig
		wantErr     bool
		expectedErr string
	}{
		{
			name:        "invalid PemKey",
			config:      pemConfig(encodeTestPrivateKey(nomadInvalidKey), encodeTestCert(nomadCert)),
			wantErr:     true,
			expectedErr: "failed to parse private key",
		},
		{
			name:        "expired certificate PemCert",
			config:      pemConfig(encodeTestPrivateKey(nomadKey), encodeTestCert(nomadExpiredCert)),
			wantErr:     true,
			expectedErr: "certificate has expired or is not yet valid",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Apply config defaults before building the assertion.
			tc.config.Canonicalize()
			token, err := BuildClientAssertionJWT(tc.config, nomadKey, nomadKID)
			if !tc.wantErr {
				must.NoError(t, err)
				must.NotNil(t, token)
				return
			}
			must.Error(t, err)
			must.StrContains(t, err.Error(), tc.expectedErr)
		})
	}
}

// generateTestPrivateKey returns a freshly generated 2048-bit RSA key. The
// test fails immediately if key generation errors.
func generateTestPrivateKey(t *testing.T) *rsa.PrivateKey {
	const bits = 2048
	priv, err := rsa.GenerateKey(rand.Reader, bits)
	must.NoError(t, err)
	return priv
}

// writeTestPrivateKeyToFile writes key as a PKCS #1 "RSA PRIVATE KEY" PEM file
// named testkey.pem inside a per-test temporary directory and returns its path.
// The file gets owner-only permissions because it holds private key material.
// Any write failure, including one reported only when the file is closed,
// fails the test.
func writeTestPrivateKeyToFile(t *testing.T, key *rsa.PrivateKey) string {
	keyPath := filepath.Join(t.TempDir(), "testkey.pem")

	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	// os.WriteFile returns errors from both the write and the final close;
	// a deferred Close would silently drop the latter.
	must.NoError(t, os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600))

	return keyPath
}

// encodeTestPrivateKey renders key as a PKCS #1 "RSA PRIVATE KEY" PEM string.
func encodeTestPrivateKey(key *rsa.PrivateKey) string {
	der := x509.MarshalPKCS1PrivateKey(key)
	return string(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}))
}

// generateTestCertificate returns a self-signed certificate for key. The
// certificate is valid from now until one hour from now.
func generateTestCertificate(t *testing.T, key *rsa.PrivateKey) *x509.Certificate {
	now := time.Now()
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1),
		NotBefore:    now,
		NotAfter:     now.Add(time.Hour),
	}
	// The template is its own parent, so the certificate is self-signed by key.
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key)
	must.NoError(t, err)

	parsed, err := x509.ParseCertificate(der)
	must.NoError(t, err)
	return parsed
}

// encodeTestCert renders cert's raw DER bytes as a "CERTIFICATE" PEM string.
func encodeTestCert(cert *x509.Certificate) string {
	encoded := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
	return string(encoded)
}

// writeTestCertToFile writes cert as a "CERTIFICATE" PEM file named
// testcert.pem inside a per-test temporary directory and returns its path.
// Any write failure, including one reported only when the file is closed,
// fails the test.
func writeTestCertToFile(t *testing.T, cert *x509.Certificate) string {
	certPath := filepath.Join(t.TempDir(), "testcert.pem")

	block := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert.Raw,
	}
	// os.WriteFile returns errors from both the write and the final close;
	// a deferred Close would silently drop the latter.
	must.NoError(t, os.WriteFile(certPath, pem.EncodeToMemory(block), 0644))

	return certPath
}

// generateExpiredTestCertificate returns a self-signed certificate for key
// whose validity window ran from two hours ago until one hour ago, so it is
// already expired when returned.
func generateExpiredTestCertificate(t *testing.T, key *rsa.PrivateKey) *x509.Certificate {
	now := time.Now()
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1),
		NotBefore:    now.Add(-2 * time.Hour),
		NotAfter:     now.Add(-1 * time.Hour),
	}
	// The template is its own parent, so the certificate is self-signed by key.
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key)
	must.NoError(t, err)

	parsed, err := x509.ParseCertificate(der)
	must.NoError(t, err)
	return parsed
}

// generateInvalidTestPrivateKey returns a 2048-bit RSA key whose public
// modulus N has been overwritten with zero. Its fields therefore no longer
// form a consistent key. Tests PEM-encode it and expect private key parsing
// to fail.
func generateInvalidTestPrivateKey(t *testing.T) *rsa.PrivateKey {
	broken, err := rsa.GenerateKey(rand.Reader, 2048)
	must.NoError(t, err)

	// Corrupt the key by zeroing its modulus; all other fields are untouched.
	broken.N = new(big.Int)
	return broken
}
func TestNewlineHeaders(t *testing.T) {
	cases := []struct {
		name    string
		content string
		expect  string
	}{
		{
			name:    "empty",
			content: "",
			expect:  "",
		},
		{
			name:    "nonsense",
			content: "not a key or cert",
			expect:  "not a key or cert",
		},
		{
			name:    "pem-shaped nonsense",
			content: "-----BEGIN RANDOM PEM-----stuff-----END RANDOM PEM-----",
			expect:  "-----BEGIN RANDOM PEM-----stuff\n-----END RANDOM PEM-----",
		},
		{
			name:    "no newlines key",
			content: "-----BEGIN ANY KIND OF PRIVATE KEY-----stuff-----END ANY PRIVATE KEY-----",
			expect:  "-----BEGIN ANY KIND OF PRIVATE KEY-----\nstuff\n-----END ANY PRIVATE KEY-----",
		},
		{
			name:    "no newlines cert",
			content: "-----BEGIN ANY KIND OF CERTIFICATE-----stuff-----END ANY CERTIFICATE-----",
			expect:  "-----BEGIN ANY KIND OF CERTIFICATE-----\nstuff\n-----END ANY CERTIFICATE-----",
		},
		// extra newlines between header/footer and content is okay.
		{
			name:    "with newlines key",
			content: "-----BEGIN ANY KIND OF PRIVATE KEY-----\nstuff\n-----END ANY PRIVATE KEY-----",
			expect:  "-----BEGIN ANY KIND OF PRIVATE KEY-----\n\nstuff\n\n-----END ANY PRIVATE KEY-----",
		},
		{
			name:    "with newlines cert",
			content: "-----BEGIN ANY KIND OF CERTIFICATE-----\nstuff\nmore\nstuff\n-----END ANY CERTIFICATE-----",
			expect:  "-----BEGIN ANY KIND OF CERTIFICATE-----\n\nstuff\nmore\nstuff\n\n-----END ANY CERTIFICATE-----",
		},
		// extra junk outside the header/footer is okay.
		{
			name:    "extra junk key",
			content: "note to self\n-----BEGIN ANY KIND OF PRIVATE KEY-----\nstuff\n-----END ANY PRIVATE KEY-----\nanother note",
			expect:  "note to self\n\n-----BEGIN ANY KIND OF PRIVATE KEY-----\n\nstuff\n\n-----END ANY PRIVATE KEY-----\n\nanother note",
		},
		{
			name:    "extra junk cert",
			content: "note to self\n-----BEGIN ANY KIND OF CERTIFICATE-----\nstuff\n-----END ANY CERTIFICATE-----\nanother note",
			expect:  "note to self\n\n-----BEGIN ANY KIND OF CERTIFICATE-----\n\nstuff\n\n-----END ANY CERTIFICATE-----\n\nanother note",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := newlineHeaders([]byte(tc.content))
			must.Eq(t, tc.expect, string(got))
		})
	}
}
