// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package iamauth

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/textproto"
	"net/url"
	"strings"

	"github.com/hashicorp/consul-awsauth/internal/stringslice"
)

// amzHeaderPrefix is the canonical-form prefix of AWS-specific request headers.
// Any GetCallerIdentity header starting with it must be explicitly allowed
// (see validateAllowedSTSHeaderValues).
const amzHeaderPrefix = "X-Amz-"

// defaultAllowedSTSRequestHeaders lists the X-Amz-* headers that are always
// permitted on the GetCallerIdentity request. Entries are written in canonical
// MIME header form because they are compared against canonicalized header names.
// Additional headers may be allowed through Config.AllowedSTSHeaderValues.
var defaultAllowedSTSRequestHeaders = []string{
	"X-Amz-Algorithm",
	"X-Amz-Content-Sha256",
	"X-Amz-Credential",
	"X-Amz-Date",
	"X-Amz-Security-Token",
	"X-Amz-Signature",
	"X-Amz-SignedHeaders",
}

// BearerToken is a login "token" for an IAM auth method: a signed
// sts:GetCallerIdentity request encoded as JSON. Optionally, the headers of
// that request embed a second signed request, iam:GetRole or iam:GetUser.
type BearerToken struct {
	// config is the auth method configuration passed to NewBearerToken.
	config *Config

	// The sts:GetCallerIdentity request, decoded from the token by UnmarshalJSON.
	getCallerIdentityMethod string
	getCallerIdentityURL    string
	getCallerIdentityHeader http.Header
	getCallerIdentityBody   string

	// The embedded iam:GetRole/iam:GetUser request, read from the
	// GetCallerIdentity headers when config.EnableIAMEntityDetails is set.
	getIAMEntityMethod string
	getIAMEntityURL    string
	getIAMEntityHeader http.Header
	getIAMEntityBody   string

	// entityRequestType is the embedded request's Action: "GetRole" or "GetUser".
	entityRequestType string
	// Parsed forms of getCallerIdentityURL and getIAMEntityURL.
	parsedCallerIdentityURL *url.URL
	parsedIAMEntityURL      *url.URL
}

// BearerToken decodes itself from JSON through its UnmarshalJSON method.
var _ json.Unmarshaler = (*BearerToken)(nil)

// NewBearerToken decodes and validates a login token for the IAM auth method.
//
// loginToken is a JSON document describing a signed sts:GetCallerIdentity
// request (see UnmarshalJSON); that request must pass validate. When
// config.EnableIAMEntityDetails is set, the token must also carry a signed
// iam:GetRole or iam:GetUser request embedded in the GetCallerIdentity headers
// named by the config's GetEntity*Header fields. That request is extracted and
// validated too (see loadIAMEntityRequest).
//
// It returns an error, and no token, if config is nil, the token cannot be
// decoded, or any validation fails.
func NewBearerToken(loginToken string, config *Config) (*BearerToken, error) {
	if config == nil {
		return nil, fmt.Errorf("missing IAM auth method config")
	}

	token := &BearerToken{config: config}
	// Unmarshal into token itself, not &token: with a **BearerToken target, a
	// login token of JSON null would reset token to nil and the validation
	// below would dereference a nil pointer.
	if err := json.Unmarshal([]byte(loginToken), token); err != nil {
		return nil, fmt.Errorf("invalid token: %w", err)
	}

	if err := token.validate(); err != nil {
		return nil, err
	}

	if config.EnableIAMEntityDetails {
		if err := token.loadIAMEntityRequest(); err != nil {
			return nil, err
		}
	}
	return token, nil
}

// loadIAMEntityRequest reads the embedded iam:GetRole/iam:GetUser request from
// the GetCallerIdentity headers named in t.config, stores it on t, and checks
// its hostname, body and query string. On success it also records the request
// type ("GetRole" or "GetUser") in t.entityRequestType.
func (t *BearerToken) loadIAMEntityRequest() error {
	method, err := t.getHeader(t.config.GetEntityMethodHeader)
	if err != nil {
		return err
	}

	rawURL, err := t.getHeader(t.config.GetEntityURLHeader)
	if err != nil {
		return err
	}

	headerJSON, err := t.getHeader(t.config.GetEntityHeadersHeader)
	if err != nil {
		return err
	}

	// The embedded request's headers travel as a JSON object inside a single
	// header value.
	var header http.Header
	if err := json.Unmarshal([]byte(headerJSON), &header); err != nil {
		return err
	}

	body, err := t.getHeader(t.config.GetEntityBodyHeader)
	if err != nil {
		return err
	}

	parsedURL, err := parseUrl(rawURL)
	if err != nil {
		return err
	}

	t.getIAMEntityMethod = method
	t.getIAMEntityBody = body
	t.getIAMEntityURL = rawURL
	t.getIAMEntityHeader = header
	t.parsedIAMEntityURL = parsedURL

	if err := t.validateIAMHostname(); err != nil {
		return err
	}

	reqType, err := t.validateIAMEntityBody()
	if err != nil {
		return err
	}

	if err := t.validateIAMEntityQueryParams(); err != nil {
		return err
	}

	t.entityRequestType = reqType
	return nil
}

// validate checks the decoded sts:GetCallerIdentity request. The method must be
// POST, and the hostname, body, query string and X-Amz-* headers must each pass
// their own check. Checks run in that order, and the first failure is returned.
//
// Modeled on Vault's equivalent validation:
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1178
func (t *BearerToken) validate() error {
	if t.getCallerIdentityMethod != "POST" {
		return fmt.Errorf("iam_http_request_method must be POST")
	}

	checks := []func() error{
		t.validateSTSHostname,
		t.validateGetCallerIdentityBody,
		t.validateGetCallerIdentityQueryParams,
		t.validateAllowedSTSHeaderValues,
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}

// validateSTSHostname checks where the token's GetCallerIdentity request will
// be sent.
//
// If the admin configured STSEndpoint, the request is always sent there
// (see GetCallerIdentityRequest), so the token's URL is not checked. Otherwise
// the token's URL is the outbound target. Its hostname must then match a known
// Amazon AWS hostname for the STS service, one of:
//
//	sts.amazonaws.com
//	sts.*.amazonaws.com
//	sts-fips.amazonaws.com
//	sts-fips.*.amazonaws.com
//
// Its scheme must also be https. Otherwise a client could downgrade the server's
// outbound call to plaintext, where the STS response could be tampered with in
// transit. The hostname is checked before the scheme.
//
// See https://docs.aws.amazon.com/general/latest/gr/sts.html
func (t *BearerToken) validateSTSHostname() error {
	if t.config.STSEndpoint != "" {
		// If an STS endpoint is configured, we (elsewhere) send the request to that endpoint.
		return nil
	}
	if t.parsedCallerIdentityURL == nil {
		return fmt.Errorf("invalid GetCallerIdentity URL: %v", t.getCallerIdentityURL)
	}

	host := t.parsedCallerIdentityURL.Hostname()
	knownHost := strings.HasSuffix(host, ".amazonaws.com") &&
		(strings.HasPrefix(host, "sts.") || strings.HasPrefix(host, "sts-fips."))
	if !knownHost {
		return fmt.Errorf("invalid STS hostname: %q", host)
	}

	// url.Parse lowercases the scheme, so a plain comparison suffices.
	if scheme := t.parsedCallerIdentityURL.Scheme; scheme != "https" {
		return fmt.Errorf("invalid STS URL scheme: %q (must be https)", scheme)
	}
	return nil
}

// validateIAMHostname checks where the token's embedded IAM entity request
// will be sent.
//
// If the admin configured IAMEndpoint, the request is always sent there
// (see GetEntityRequest), so the embedded URL is not checked. Otherwise the
// embedded URL is the outbound target. Its hostname must then match a known
// Amazon AWS hostname for the IAM service, one of:
//
//	iam.amazonaws.com
//	iam.*.amazonaws.com
//	iam-fips.amazonaws.com
//	iam-fips.*.amazonaws.com
//
// Its scheme must also be https. Otherwise a client could downgrade the server's
// outbound call to plaintext, where the IAM response could be tampered with in
// transit. The hostname is checked before the scheme.
//
// See https://docs.aws.amazon.com/general/latest/gr/iam-service.html
func (t *BearerToken) validateIAMHostname() error {
	if t.config.IAMEndpoint != "" {
		// If an IAM endpoint is configured, we (elsewhere) send the request to that endpoint.
		return nil
	}
	if t.parsedIAMEntityURL == nil {
		return fmt.Errorf("invalid IAM URL: %v", t.getIAMEntityURL)
	}

	host := t.parsedIAMEntityURL.Hostname()
	knownHost := strings.HasSuffix(host, ".amazonaws.com") &&
		(strings.HasPrefix(host, "iam.") || strings.HasPrefix(host, "iam-fips."))
	if !knownHost {
		return fmt.Errorf("invalid IAM hostname: %q", host)
	}

	// url.Parse lowercases the scheme, so a plain comparison suffices.
	if scheme := t.parsedIAMEntityURL.Scheme; scheme != "https" {
		return fmt.Errorf("invalid IAM URL scheme: %q (must be https)", scheme)
	}
	return nil
}

// validateGetCallerIdentityBody checks that the GetCallerIdentity body only
// requests Action=GetCallerIdentity, optionally with a Version field of any
// value. Parse or validation failures are reported as iam_request_body errors.
//
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1439
func (t *BearerToken) validateGetCallerIdentityBody() error {
	allowed := url.Values{}
	allowed["Action"] = []string{"GetCallerIdentity"}
	// Version may carry any value: future API versions are assumed to keep
	// the same semantics.
	allowed["Version"] = nil

	_, err := parseRequestBody(t.getCallerIdentityBody, allowed)
	if err == nil {
		return nil
	}
	return fmt.Errorf("iam_request_body error: %s", err)
}

// validateIAMEntityBody checks the embedded IAM entity request body and returns
// its Action, either "GetRole" or "GetUser".
//
// The body may only contain Action, RoleName, UserName and Version fields.
// GetRole must name a role and no user, and GetUser must name a user and no role.
func (t *BearerToken) validateIAMEntityBody() (string, error) {
	allowed := url.Values{
		"Action":   {"GetRole", "GetUser"},
		"RoleName": nil, // any value is allowed
		"UserName": nil,
		"Version":  nil,
	}
	parsed, err := parseRequestBody(t.getIAMEntityBody, allowed)
	if err != nil {
		return "", fmt.Errorf("iam_request_headers[%s] error: %s", t.config.GetEntityBodyHeader, err)
	}

	// parseRequestBody guarantees at least one Action value.
	_, roleNamed := parsed["RoleName"]
	_, userNamed := parsed["UserName"]
	switch action := parsed["Action"][0]; {
	case action == "GetUser" && userNamed && !roleNamed,
		action == "GetRole" && roleNamed && !userNamed:
		return action, nil
	}
	return "", fmt.Errorf("iam_request_headers[%q] error: invalid request body %q", t.config.GetEntityBodyHeader, t.getIAMEntityBody)
}

// validateGetCallerIdentityQueryParams rejects a GetCallerIdentity URL that has
// any query string. AWS STS prefers URL parameters over body parameters, so
// allowing them would let an attacker pass body validation with harmless values
// while smuggling different ones in the URL. A missing parsed URL is not
// treated as an error here.
func (t *BearerToken) validateGetCallerIdentityQueryParams() error {
	if u := t.parsedCallerIdentityURL; u != nil && u.RawQuery != "" {
		return fmt.Errorf("URL query parameters are not allowed for security reasons: found %q", u.RawQuery)
	}
	return nil
}

// validateIAMEntityQueryParams rejects an IAM entity URL that has any query
// string, so that parameters cannot be injected outside the validated body.
// A missing parsed URL is not treated as an error here.
func (t *BearerToken) validateIAMEntityQueryParams() error {
	if u := t.parsedIAMEntityURL; u != nil && u.RawQuery != "" {
		return fmt.Errorf("URL query parameters are not allowed for IAM entity requests: found %q", u.RawQuery)
	}
	return nil
}

// parseRequestBody parses an AWS STS or IAM form-encoded request body, such as
// 'Action=GetRole&RoleName=my-role'. It returns the parsed values, or an error if
// the body is malformed or has unexpected fields or values.
//
// A key-value pair in the body is allowed only if:
//   - the key appears exactly once (bodies like 'Action=1&Action=2' or
//     'RoleName=a&RoleName=b' are rejected), and
//   - allowedValues[key] is an empty slice or nil (any value is allowed for
//     the key), or allowedValues[key] is non-empty and contains the exact value.
//
// Keys other than "Action" must be present in allowedValues. An "Action" field
// with a non-empty value is always required.
func parseRequestBody(body string, allowedValues url.Values) (url.Values, error) {
	qs, err := url.ParseQuery(body)
	if err != nil {
		return nil, err
	}

	// Action field is always required. Get returns "" when the key is absent
	// or its first value is empty.
	if qs.Get("Action") == "" {
		return nil, fmt.Errorf(`missing field "Action"`)
	}

	// Ensure the body does not have extra fields and each
	// field in the body matches the allowed values.
	for k, v := range qs {
		exp, ok := allowedValues[k]
		if k != "Action" && !ok {
			return nil, fmt.Errorf("unexpected field %q", k)
		}

		// Repeated keys are rejected for every field, including "any value"
		// ones. Otherwise we might validate one occurrence while it is
		// ambiguous which occurrence AWS acts on (parameter pollution).
		if len(v) != 1 {
			return nil, fmt.Errorf("unexpected value %s=%v", k, v)
		}

		// An empty expectation list means any single value is okay.
		if len(exp) > 0 && !stringslice.Contains(exp, v[0]) {
			return nil, fmt.Errorf("unexpected value %s=%v", k, v)
		}
	}

	return qs, nil
}

// validateAllowedSTSHeaderValues rejects a GetCallerIdentity request that carries
// an X-Amz-* header listed neither in defaultAllowedSTSRequestHeaders nor in
// config.AllowedSTSHeaderValues. Headers without that prefix are not checked.
//
// Header names are case-insensitive, so the request's header names and both
// allow-lists are compared in canonical MIME header form. A configured value
// such as "x-amz-foo" therefore permits an "X-Amz-Foo" header.
//
// https://github.com/hashicorp/vault/blob/861454e0ed1390d67ddaf1a53c1798e5e291728c/builtin/credential/aws/path_config_client.go#L349
func (t *BearerToken) validateAllowedSTSHeaderValues() error {
	for k := range t.getCallerIdentityHeader {
		h := textproto.CanonicalMIMEHeaderKey(k)
		if !strings.HasPrefix(h, amzHeaderPrefix) {
			continue
		}
		if containsHeaderKey(defaultAllowedSTSRequestHeaders, h) ||
			containsHeaderKey(t.config.AllowedSTSHeaderValues, h) {
			continue
		}
		return fmt.Errorf("invalid request header: %s", h)
	}
	return nil
}

// containsHeaderKey reports whether any entry of keys, once canonicalized,
// equals canonicalKey (which must already be in canonical MIME header form).
func containsHeaderKey(keys []string, canonicalKey string) bool {
	for _, k := range keys {
		if textproto.CanonicalMIMEHeaderKey(k) == canonicalKey {
			return true
		}
	}
	return false
}

// UnmarshalJSON decodes a bearer token: a JSON object describing a signed
// sts:GetCallerIdentity HTTP request. The URL, headers and body fields are
// standard base64, and the decoded headers are themselves a JSON object. The
// URL must be absolute (see parseUrl).
//
// Decoding errors are returned unwrapped. The method, URL string, headers and
// body are stored on t before the URL is parsed, so they are set even when
// parsing the URL fails.
func (t *BearerToken) UnmarshalJSON(data []byte) error {
	var wire struct {
		Method  string `json:"iam_http_request_method"`
		URL     string `json:"iam_request_url"`
		Headers string `json:"iam_request_headers"`
		Body    string `json:"iam_request_body"`
	}
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}

	decode := base64.StdEncoding.DecodeString

	urlBytes, err := decode(wire.URL)
	if err != nil {
		return err
	}

	headerBytes, err := decode(wire.Headers)
	if err != nil {
		return err
	}
	// The headers field is a JSON document nested (base64-encoded) in JSON.
	var hdr http.Header
	if err = json.Unmarshal(headerBytes, &hdr); err != nil {
		return err
	}

	bodyBytes, err := decode(wire.Body)
	if err != nil {
		return err
	}

	t.getCallerIdentityMethod = wire.Method
	t.getCallerIdentityURL = string(urlBytes)
	t.getCallerIdentityHeader = hdr
	t.getCallerIdentityBody = string(bodyBytes)

	parsed, err := parseUrl(t.getCallerIdentityURL)
	if err != nil {
		return err
	}
	t.parsedCallerIdentityURL = parsed
	return nil
}

// parseUrl parses s as an absolute URL. url.Parse accepts empty and relative
// inputs without error, so a result lacking a scheme or host is rejected here.
func parseUrl(s string) (*url.URL, error) {
	parsed, err := url.Parse(s)
	switch {
	case err != nil:
		return nil, err
	case parsed == nil, parsed.Scheme == "", parsed.Host == "":
		return nil, fmt.Errorf("url is invalid: %q", s)
	default:
		return parsed, nil
	}
}

// GetCallerIdentityRequest rebuilds the signed sts:GetCallerIdentity request
// carried by the bearer token.
//
// To avoid acting as an unintended network proxy, the request is sent to one
// of two targets. It goes to the admin-configured STSEndpoint when set.
// Otherwise it goes to the token's own URL, which NewBearerToken has already
// checked against known AWS STS hostnames (validateSTSHostname).
func (t *BearerToken) GetCallerIdentityRequest() (*http.Request, error) {
	target := t.config.STSEndpoint
	if target == "" {
		target = t.getCallerIdentityURL
	}

	return buildHttpRequest(
		t.getCallerIdentityMethod,
		target,
		t.parsedCallerIdentityURL,
		t.getCallerIdentityBody,
		t.getCallerIdentityHeader,
	)
}

// GetEntityRequest rebuilds the iam:GetUser or iam:GetRole request embedded in
// the headers of the sts:GetCallerIdentity request. The request is sent to the
// admin-configured IAMEndpoint when set, and otherwise to the token's own IAM URL.
//
// NewBearerToken only extracts the embedded request when
// config.EnableIAMEntityDetails is set; otherwise the IAM fields remain unset.
func (t *BearerToken) GetEntityRequest() (*http.Request, error) {
	target := t.config.IAMEndpoint
	if target == "" {
		target = t.getIAMEntityURL
	}

	return buildHttpRequest(
		t.getIAMEntityMethod,
		target,
		t.parsedIAMEntityURL,
		t.getIAMEntityBody,
		t.getIAMEntityHeader,
	)
}

// getHeader returns the single value of header name on the GetCallerIdentity
// request. It returns an error if the header is missing or has more than one
// value.
func (t *BearerToken) getHeader(name string) (string, error) {
	switch vals := t.getCallerIdentityHeader.Values(name); len(vals) {
	case 0:
		return "", fmt.Errorf("missing header %q", name)
	case 1:
		return vals[0], nil
	default:
		return "", fmt.Errorf("invalid value for header %q (expected 1 item)", name)
	}
}

// buildHttpRequest returns an HTTP request from the given details.
//
// The request is sent to endpoint, but parsedUrl's path is appended to
// endpoint's path, and parsedUrl's query string and Host are preserved. The
// Host header and URI path are signed and cannot be modified.
// There's a deeper explanation of this in the Vault source code.
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1569
//
// Each value in headers is copied onto the request. It returns an error if
// parsedUrl is nil (for example, when no embedded entity request was
// extracted), if endpoint does not parse, or if the request cannot be built.
func buildHttpRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*http.Request, error) {
	if parsedUrl == nil {
		return nil, fmt.Errorf("missing request URL")
	}

	targetUrl, err := url.Parse(endpoint)
	if err != nil {
		return nil, err
	}
	// Join using the escaped form of the signed path. Feeding the decoded Path
	// through url.JoinPath and storing its escaped result back into Path would
	// escape it twice (e.g. "/a%20b" -> "/a%2520b") and break the signature.
	targetUrl = targetUrl.JoinPath(parsedUrl.EscapedPath())
	targetUrl.RawQuery = parsedUrl.RawQuery

	request, err := http.NewRequest(method, targetUrl.String(), strings.NewReader(body))
	if err != nil {
		return nil, err
	}
	request.Host = parsedUrl.Host
	for k, vals := range headers {
		for _, val := range vals {
			request.Header.Add(k, val)
		}
	}
	return request, nil
}
