/**
 * Copyright (C) 2015 MongoDB Inc.
 */

#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage

#include "mongo/platform/basic.h"

#include "kmip_service.h"

#include <memory>
#include <vector>

#include "encryption_options.h"
#include "kmip_consts.h"
#include "kmip_request.h"
#include "kmip_response.h"
#include "mongo/base/data_range_cursor.h"
#include "mongo/base/data_type_endian.h"
#include "mongo/base/init.h"
#include "mongo/base/status_with.h"
#include "mongo/config.h"
#include "mongo/crypto/symmetric_crypto.h"
#include "mongo/logv2/log.h"
#include "mongo/platform/mutex.h"
#include "mongo/util/net/ssl_manager.h"
#include "mongo/util/net/ssl_options.h"
#include "mongo/util/secure_zero_memory.h"

namespace mongo {
namespace kmip {

// Builds a KMIPService for one specific KMIP server and brings its connection up.
//
// An SSL manager is created from 'sslKMIPParams', the service object is constructed for
// 'server', and the connection setup is run with 'connectTimeout'. Returns the connected
// service on success, the Status of the failed connection setup otherwise. A DBException
// thrown by any of these steps is converted into its Status and returned.
StatusWith<KMIPService> KMIPService::createKMIPService(const HostAndPort& server,
                                                       const SSLParams& sslKMIPParams,
                                                       Milliseconds connectTimeout) {
    try {
        std::shared_ptr<SSLManagerInterface> manager =
            SSLManagerInterface::create(sslKMIPParams, false);

        KMIPService connected(server, std::move(manager));
        if (Status initStatus = connected._initServerConnection(connectTimeout);
            !initStatus.isOK()) {
            return initStatus;
        }

        // The return type is StatusWith<KMIPService> rather than KMIPService itself, so the
        // local object is moved into the result explicitly.
        return std::move(connected);
    } catch (const DBException& ex) {
        return ex.toStatus();
    }
}

// Builds a KMIPService from the KMIP configuration, trying the configured servers in order.
//
// The SSL settings are derived from 'kmipParams' and 'sslFIPSMode', with invalid certificates
// and invalid hostnames always disallowed. Each entry of 'kmipParams.kmipServerName' is tried on
// 'kmipParams.kmipPort'. When the whole list has failed, the pass over the list is repeated
// while retries from 'kmipParams.kmipConnectRetries' remain. Returns the first service that
// connects, the Status of the last failed attempt once the retries are used up, or BadValue if
// the server list is empty.
StatusWith<KMIPService> KMIPService::createKMIPService(const KMIPParams& kmipParams,
                                                       bool sslFIPSMode) {
    SSLParams sslKMIPParams;

    // Client identity and trust anchors come from the KMIP configuration.
    sslKMIPParams.sslPEMKeyFile = kmipParams.kmipClientCertificateFile;
    sslKMIPParams.sslPEMKeyPassword = kmipParams.kmipClientCertificatePassword;
    sslKMIPParams.sslCAFile = kmipParams.kmipServerCAFile;
#ifdef MONGO_CONFIG_SSL_CERTIFICATE_SELECTORS
    sslKMIPParams.sslCertificateSelector = kmipParams.kmipClientCertificateSelector;
#endif
    sslKMIPParams.sslFIPSMode = sslFIPSMode;

    // Settings that are never used for KMIP connections are cleared explicitly.
    sslKMIPParams.sslClusterFile = "";
    sslKMIPParams.sslClusterPassword = "";
    sslKMIPParams.sslCRLFile = "";

    // A KMIP server must always present a valid certificate for the expected host name.
    sslKMIPParams.sslAllowInvalidCertificates = false;
    sslKMIPParams.sslAllowInvalidHostnames = false;

    Milliseconds connectTimeout(kmipParams.kmipConnectTimeoutMS);

    const auto& serverNames = kmipParams.kmipServerName;

    // One pass over the server list is always made; further passes consume the retries.
    auto attemptsLeft = kmipParams.kmipConnectRetries;
    do {
        for (auto current = serverNames.begin(); current != serverNames.end(); ++current) {
            HostAndPort endpoint(*current, kmipParams.kmipPort);
            auto swCandidate = createKMIPService(endpoint, sslKMIPParams, connectTimeout);
            if (swCandidate.isOK()) {
                return std::move(swCandidate.getValue());
            }

            // Not the last server in the list: report the failure and move on to the next one.
            const auto upcoming = current + 1;
            if (upcoming != serverNames.end()) {
                LOGV2_WARNING(24240,
                              "Connection to KMIP server failed. Trying next server",
                              "failedHost"_attr = endpoint,
                              "nextHost"_attr = HostAndPort(*upcoming, kmipParams.kmipPort));
                continue;
            }

            // The last server failed and nothing is left to retry: give up with its error.
            if (!attemptsLeft) {
                return swCandidate.getStatus();
            }

            LOGV2_WARNING(24241,
                          "Connection to KMIP server failed. Restarting connect "
                          "attempt(s) with remaining retries",
                          "host"_attr = endpoint,
                          "retryCount"_attr = attemptsLeft);
        }
    } while (attemptsLeft--);

    // Falling out of the loop means no attempt was ever made, i.e. the server list is empty.
    return {ErrorCodes::BadValue, "No KMIP server specified."};
}

// Constructs a service object for the KMIP server at 'server', taking shared ownership of
// 'sslManager'. Only the socket object is allocated here; the constructor body is empty and no
// connection is attempted.
// NOTE(review): the socket is created with a hard-coded 10 and LogSeverity::Log(); the 10 is
// presumably a socket timeout in seconds -- confirm against Socket's constructor.
KMIPService::KMIPService(const HostAndPort& server, std::shared_ptr<SSLManagerInterface> sslManager)
    : _sslManager(std::move(sslManager)),
      _server(server),
      _socket(std::make_unique<Socket>(10, logv2::LogSeverity::Log())) {}

// Sends a KMIP create request and returns the UID carried in the server's response.
//
// A failure to exchange the request, as well as a response whose result status is not
// success, is reported as a BadValue Status describing the problem.
StatusWith<std::string> KMIPService::createExternalKey() {
    auto swReply = _sendRequest(_generateKMIPCreateRequest());
    if (!swReply.isOK()) {
        // Only the reason text of the underlying failure is kept; the code becomes BadValue.
        const Status& sendFailure = swReply.getStatus();
        return Status(ErrorCodes::BadValue,
                      str::stream() << "KMIP create key failed: " << sendFailure.reason());
    }

    const KMIPResponse& reply = swReply.getValue();
    if (reply.getResultStatus() != kmip::statusSuccess) {
        return Status(ErrorCodes::BadValue,
                      str::stream() << "KMIP create key failed, code: " << reply.getResultReason()
                                    << " error: " << reply.getResultMsg());
    }

    return reply.getUID();
}

// Sends a KMIP get request for the key identified by 'uid' and returns the key from the
// response.
//
// A failure to exchange the request is returned unchanged. A response whose result status is
// not success, or a key whose size differs from crypto::sym256KeySize, yields BadValue.
StatusWith<std::unique_ptr<SymmetricKey>> KMIPService::getExternalKey(const std::string& uid) {
    auto swReply = _sendRequest(_generateKMIPGetRequest(uid));
    if (!swReply.isOK()) {
        return swReply.getStatus();
    }

    KMIPResponse reply = std::move(swReply.getValue());
    if (reply.getResultStatus() != kmip::statusSuccess) {
        return Status(ErrorCodes::BadValue,
                      str::stream() << "KMIP get key '" << uid
                                    << "' failed, code: " << reply.getResultReason()
                                    << " error: " << reply.getResultMsg());
    }

    std::unique_ptr<SymmetricKey> fetchedKey = reply.getSymmetricKey();
    const auto actualKeySize = fetchedKey->getKeySize();
    if (actualKeySize == crypto::sym256KeySize) {
        return std::move(fetchedKey);
    }

    // Key sizes are multiplied by 8 so the message reports them in bits.
    return Status(ErrorCodes::BadValue,
                  str::stream() << "KMIP got a key which was " << actualKeySize * 8
                                << " bits long, but a " << crypto::sym256KeySize * 8
                                << " bit key is required");
}

// Connects the socket to the configured KMIP server and secures the connection.
//
// The steps run in order: build the socket address from the server's host and port, connect
// within 'connectTimeout', then secure the socket using the SSL manager and the server's host
// name. The first step that fails ends the setup with a BadValue Status naming that step;
// Status::OK() is returned when all of them succeed.
Status KMIPService::_initServerConnection(Milliseconds connectTimeout) {
    // Step 1: the host/port pair must yield a valid socket address.
    SockAddr remote(_server.host().c_str(), _server.port(), AF_UNSPEC);
    if (!remote.isValid()) {
        return Status(ErrorCodes::BadValue,
                      str::stream() << "KMIP server address " << _server.host() << " is invalid.");
    }

    // Step 2: establish the connection within the given timeout.
    if (!_socket->connect(remote, connectTimeout)) {
        return Status(ErrorCodes::BadValue,
                      str::stream() << "Could not connect to KMIP server " << remote.toString());
    }

    // Step 3: secure the connected socket.
    if (!_socket->secure(_sslManager.get(), _server.host())) {
        return Status(ErrorCodes::BadValue,
                      str::stream() << "Failed to perform SSL handshake with the KMIP server "
                                    << _server.toString());
    }

    return Status::OK();
}

// Sends a request message to the KMIP server, reads the reply, and creates a KMIPResponse from
// it.
//
// Returns FailedToParse when the reply does not begin with the response message tag, the
// error from decoding the length field, or whatever KMIPResponse::create() returns for the
// complete message. Trips massert 4044 when the announced body does not fit the local buffer.
// The receive buffer is wiped on every exit path because the reply may contain key material.
StatusWith<KMIPResponse> KMIPService::_sendRequest(const std::vector<uint8_t>& request) {
    // Size of the fixed header that precedes the message body.
    constexpr size_t kHeaderLength = 8;

    char resp[2000];

    // Wipes the whole buffer when the function is left, whether by a normal return, an early
    // error return, or an exception propagating out of the calls below.
    struct BufferWiper {
        char* data;
        size_t size;
        ~BufferWiper() {
            secureZeroMemory(data, size);
        }
    } wiper{resp, sizeof(resp)};

    _socket->send(reinterpret_cast<const char*>(request.data()), request.size(), "KMIP request");

    /**
     *  Read the response header, which has the form:
     *  data[0:2] - tag identifier
     *  data[3]   - tag type
     *  data[4:7] - big endian encoded message body length
     */
    _socket->recv(resp, kHeaderLength);
    if (memcmp(resp, kmip::responseMessageTag, 3) != 0 ||
        resp[3] != static_cast<char>(ItemType::structure)) {
        return Status(ErrorCodes::FailedToParse,
                      "Expected KMIP response message to start with "
                      "response message tag");
    }

    // Decode the 4-byte big-endian body length stored in header bytes 4..7.
    ConstDataRangeCursor cdrc(resp + 4, resp + 8);
    StatusWith<BigEndian<uint32_t>> swBodyLength = cdrc.readAndAdvance<BigEndian<uint32_t>>();
    if (!swBodyLength.isOK()) {
        return swBodyLength.getStatus();
    }

    // Refuse a body that would overflow the buffer space left after the header.
    uint32_t bodyLength = static_cast<uint32_t>(swBodyLength.getValue());
    massert(4044,
            "KMIP server response is too long",
            bodyLength <= sizeof(resp) - kHeaderLength);
    _socket->recv(&resp[kHeaderLength], bodyLength);

    return KMIPResponse::create(resp, bodyLength + kHeaderLength);
}

std::vector<uint8_t> KMIPService::_generateKMIPGetRequest(const std::string& uid) {
    std::vector<uint8_t> uuid(std::begin(uid), std::end(uid));
    mongo::kmip::GetKMIPRequestParameters getRequestParams(uuid);
    return encodeKMIPRequest(getRequestParams);
}

std::vector<uint8_t> KMIPService::_generateKMIPCreateRequest() {
    std::vector<uint8_t> algorithm(std::begin(aesCryptoAlgorithm), std::end(aesCryptoAlgorithm));
    std::vector<uint8_t> length = convertIntToBigEndianArray(256);
    std::vector<uint8_t> usageMask{
        0x00, 0x00, 0x00, cryptoUsageMaskEncrypt | cryptoUsageMaskDecrypt};

    CreateKMIPRequestParameters createRequestParams(algorithm, length, usageMask);
    return encodeKMIPRequest(createRequestParams);
}
}  // namespace kmip
}  // namespace mongo
