Skip to content

Commit

Permalink
Add client X509 certificate based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
halfgaar committed Nov 4, 2023
1 parent abbdc0c commit de2aa39
Show file tree
Hide file tree
Showing 17 changed files with 229 additions and 8 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ add_executable(flashmq
dnsresolver.cpp
bridgeinfodb.h bridgeinfodb.cpp
globber.cpp
x509manager.h x509manager.cpp

)

Expand Down
2 changes: 2 additions & 0 deletions FlashMQTests/FlashMQTests.pro
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ SOURCES += tst_maintests.cpp \
../dnsresolver.cpp \
../bridgeinfodb.cpp \
../globber.cpp \
../x509manager.cpp \
conffiletemp.cpp \
dnstests.cpp \
filecloser.cpp \
Expand Down Expand Up @@ -133,6 +134,7 @@ HEADERS += \
../dnsresolver.h \
../bridgeinfodb.h \
../globber.h \
../x509manager.h \
conffiletemp.h \
filecloser.h \
flashmqtempdir.h \
Expand Down
2 changes: 1 addition & 1 deletion bridgeconfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void BridgeConfig::isValid()
port = 8883;
}

testSslVerifyLocations(caFile, caDir);
testSslVerifyLocations(caFile, caDir, "Loading bridge ca_file/ca_dir failed.");
}
else
{
Expand Down
60 changes: 60 additions & 0 deletions client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,3 +958,63 @@ std::string &Client::getMutableUsername()
return this->username;
}

void Client::setSslVerify(X509ClientVerification verificationMode)
{
const int mode = verificationMode > X509ClientVerification::None ? SSL_VERIFY_PEER : SSL_VERIFY_NONE;
this->x509ClientVerification = verificationMode;
ioWrapper.setSslVerify(mode, "");
}

std::optional<std::string> Client::getUsernameFromPeerCertificate()
{
if (!ioWrapper.isSsl() || x509ClientVerification == X509ClientVerification::None)
return std::optional<std::string>();

X509Manager client_cert = ioWrapper.getPeerCertificate();

if (!client_cert)
throw ProtocolError("Client did not provide X509 peer certificate", ReasonCodes::BadUserNameOrPassword);

X509_NAME *x509_name = X509_get_subject_name(client_cert.get());
int index = X509_NAME_get_index_by_NID(x509_name, NID_commonName, -1);

if (index < 0)
return std::optional<std::string>();

X509_NAME_ENTRY *name_entry = X509_NAME_get_entry(x509_name, index);

if (!name_entry)
throw std::runtime_error("X509_NAME_get_entry failed. This should be impossible.");

ASN1_STRING *asn1_string = X509_NAME_ENTRY_get_data(name_entry);

if (!asn1_string)
throw std::runtime_error("Cannot obtain asn1 string from x509 certificate.");

const unsigned char *str = ASN1_STRING_get0_data(asn1_string);

if (!str)
throw std::runtime_error("ASN1_STRING_get0_data failed. This should be impossible.");

std::string username(reinterpret_cast<const char*>(str));

if (!isValidUtf8(username))
throw ProtocolError("Common name from peer certificate is not valid UTF8.", ReasonCodes::MalformedPacket);

return username;
}

X509ClientVerification Client::getX509ClientVerification() const
{
return x509ClientVerification;
}










7 changes: 7 additions & 0 deletions client.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ See LICENSE for license details.
#include <mutex>
#include <iostream>
#include <time.h>
#include <optional>

#include <openssl/ssl.h>
#include <openssl/err.h>
Expand All @@ -28,6 +29,7 @@ See LICENSE for license details.
#include "types.h"
#include "iowrapper.h"
#include "bridgeconfig.h"
#include "enums.h"

#include "publishcopyfactory.h"

Expand Down Expand Up @@ -85,6 +87,7 @@ class Client
std::string username;
uint16_t keepalive = 10;
bool clean_start = false;
X509ClientVerification x509ClientVerification = X509ClientVerification::None;

std::shared_ptr<WillPublish> stagedWillPublish;
std::shared_ptr<WillPublish> willPublish;
Expand Down Expand Up @@ -217,6 +220,10 @@ class Client
void setFakeUpgraded();
#endif

void setSslVerify(X509ClientVerification verificationMode);
std::optional<std::string> getUsernameFromPeerCertificate();
X509ClientVerification getX509ClientVerification() const;

};

#endif // CLIENT_H
18 changes: 18 additions & 0 deletions configfileparser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
validListenKeys.insert("inet4_bind_address");
validListenKeys.insert("inet6_bind_address");
validListenKeys.insert("haproxy");
validListenKeys.insert("client_verification_ca_file");
validListenKeys.insert("client_verification_ca_dir");
validListenKeys.insert("client_verification_still_do_authn");

validBridgeKeys.insert("local_username");
validBridgeKeys.insert("remote_username");
Expand Down Expand Up @@ -358,6 +361,8 @@ void ConfigFileParser::loadFile(bool test)

std::string key = matches[1].str();
const std::string value = matches[2].str();
std::string valueTrimmed = value;
trim(valueTrimmed);

try
{
Expand Down Expand Up @@ -407,6 +412,19 @@ void ConfigFileParser::loadFile(bool test)
bool val = stringTruthiness(value);
curListener->haproxy = val;
}
if (testKeyValidity(key, "client_verification_ca_file", validListenKeys))
{
curListener->clientVerificationCaFile = valueTrimmed;
}
if (testKeyValidity(key, "client_verification_ca_dir", validListenKeys))
{
curListener->clientVerificationCaDir = valueTrimmed;
}
if (testKeyValidity(key, "client_verification_still_do_authn", validListenKeys))
{
bool val = stringTruthiness(value);
curListener->clientVerifictionStillDoAuthn = val;
}

continue;
}
Expand Down
6 changes: 6 additions & 0 deletions enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,11 @@ See LICENSE for license details.
#ifndef ENUMS_H
#define ENUMS_H

enum class X509ClientVerification
{
None,
X509IsEnough,
X509AndUsernamePassword
};

#endif // ENUMS_H
17 changes: 13 additions & 4 deletions iowrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,14 @@ void IoWrapper::setSslVerify(int mode, const std::string &hostname)

SSL_set_hostflags(ssl, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);

if (!SSL_set1_host(ssl, hostname.c_str()))
throw std::runtime_error("Failed setting hostname of SSL context.");
if (!hostname.empty())
{
if (!SSL_set1_host(ssl, hostname.c_str()))
throw std::runtime_error("Failed setting hostname of SSL context.");

if (SSL_set_tlsext_host_name(ssl, hostname.c_str()) != 1)
throw std::runtime_error("Failed setting SNI hostname of SSL context.");
if (SSL_set_tlsext_host_name(ssl, hostname.c_str()) != 1)
throw std::runtime_error("Failed setting SNI hostname of SSL context.");
}

SSL_set_verify(ssl, mode, verify_callback);
}
Expand Down Expand Up @@ -268,6 +271,12 @@ WebsocketState IoWrapper::getWebsocketState() const
return websocketState;
}

X509Manager IoWrapper::getPeerCertificate() const
{
X509Manager result(this->ssl);
return result;
}

bool IoWrapper::needsHaProxyParsing() const
{
return _needsHaProxyParsing;
Expand Down
2 changes: 2 additions & 0 deletions iowrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ See LICENSE for license details.
#include "logger.h"
#include "haproxy.h"
#include "cirbuf.h"
#include "x509manager.h"

#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2
#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10
Expand Down Expand Up @@ -136,6 +137,7 @@ class IoWrapper
bool hasProcessedBufferedBytesToRead() const;
bool isWebsocket() const;
WebsocketState getWebsocketState() const;
X509Manager getPeerCertificate() const;

bool needsHaProxyParsing() const;
HaProxyConnectionType readHaProxyData(int fd, struct sockaddr *addr);
Expand Down
38 changes: 38 additions & 0 deletions listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ it under the terms of The Open Software License 3.0 (OSL-3.0).
See LICENSE for license details.
*/

#include <openssl/err.h>

#include "listener.h"

#include "utils.h"
#include "exceptions.h"
#include "logger.h"

void Listener::isValid()
{
Expand All @@ -26,6 +29,7 @@ void Listener::isValid()
}

testSsl(sslFullchain, sslPrivkey);
testSslVerifyLocations(clientVerificationCaFile, clientVerificationCaDir, "Loading client_verification_ca_dir/client_verification_ca_file failed.");
}
else
{
Expand All @@ -38,6 +42,11 @@ void Listener::isValid()
}
}

if ((!clientVerificationCaDir.empty() || !clientVerificationCaFile.empty()) && !isSsl())
{
throw ConfigFileException("X509 client verification can only be done on TLS listeners.");
}

if (port <= 0 || port > 65534)
{
throw ConfigFileException(formatString("Port nr %d is not valid", port));
Expand Down Expand Up @@ -98,6 +107,35 @@ void Listener::loadCertAndKeyFromConfig()
throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected.");
if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");

{
const char *ca_file = clientVerificationCaFile.empty() ? nullptr : clientVerificationCaFile.c_str();
const char *ca_dir = clientVerificationCaDir.empty() ? nullptr : clientVerificationCaDir.c_str();

if (ca_file || ca_dir)
{
if (SSL_CTX_load_verify_locations(sslctx->get(), ca_file, ca_dir) != 1)
{
ERR_print_errors_cb(logSslError, NULL);
throw std::runtime_error("Loading client_verification_ca_dir/client_verification_ca_file failed. "
"This was after test loading the certificate, so is very unexpected.");
}
}
}
}

X509ClientVerification Listener::getX509ClientVerficiationMode() const
{
X509ClientVerification result = X509ClientVerification::None;
const bool clientCADefined = !clientVerificationCaDir.empty() || !clientVerificationCaFile.empty();

if (clientCADefined)
result = X509ClientVerification::X509IsEnough;

if (result >= X509ClientVerification::X509IsEnough && clientVerifictionStillDoAuthn)
result = X509ClientVerification::X509AndUsernamePassword;

return result;
}

std::string Listener::getBindAddress(ListenerProtocol p)
Expand Down
5 changes: 5 additions & 0 deletions listener.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ See LICENSE for license details.
#include <memory>

#include "sslctxmanager.h"
#include "enums.h"

enum class ListenerProtocol
{
Expand All @@ -33,13 +34,17 @@ struct Listener
bool haproxy = false;
std::string sslFullchain;
std::string sslPrivkey;
std::string clientVerificationCaFile;
std::string clientVerificationCaDir;
bool clientVerifictionStillDoAuthn = false;
std::unique_ptr<SslCtxManager> sslctx;

void isValid();
bool isSsl() const;
bool isHaProxy() const;
std::string getProtocolName() const;
void loadCertAndKeyFromConfig();
X509ClientVerification getX509ClientVerficiationMode() const;

std::string getBindAddress(ListenerProtocol p);
};
Expand Down
5 changes: 5 additions & 0 deletions mainapp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,11 @@ void MainApp::start()
// Don't use std::make_shared to avoid the weak pointers keeping the control block in memory.
std::shared_ptr<Client> client = std::shared_ptr<Client>(new Client(fd, thread_data, clientSSL, listener->websocket, listener->isHaProxy(), addr, settings));

if (listener->getX509ClientVerficiationMode() != X509ClientVerification::None)
{
client->setSslVerify(listener->getX509ClientVerficiationMode());
}

thread_data->giveClient(std::move(client));

globalStats->socketConnects.inc();
Expand Down
16 changes: 16 additions & 0 deletions mqttpacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,17 @@ void MqttPacket::handleConnect()

ConnectData connectData = parseConnectData();

if (sender->getX509ClientVerification() > X509ClientVerification::None)
{
std::optional<std::string> certificateUsername = sender->getUsernameFromPeerCertificate();

if (!certificateUsername || certificateUsername.value().empty())
throw ProtocolError("Client certificate did not provider username", ReasonCodes::BadUserNameOrPassword);

connectData.user_name_flag = true;
connectData.username = certificateUsername.value();
}

sender->setBridge(connectData.bridge);

if (this->protocolVersion == ProtocolVersion::None)
Expand Down Expand Up @@ -926,6 +937,11 @@ void MqttPacket::handleConnect()
{
authResult = AuthResult::success;
}
else if (sender->getX509ClientVerification() == X509ClientVerification::X509IsEnough)
{
// The client will have been kicked out already if the certificate is not valid, so we can just approve it.
authResult = AuthResult::success;
}
else if (connectData.authenticationMethod.empty())
{
authResult = authentication.unPwdCheck(connectData.client_id, connectData.username, connectData.password, getUserProperties(), sender);
Expand Down
4 changes: 2 additions & 2 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ void testSsl(const std::string &fullchain, const std::string &privkey)
}
}

void testSslVerifyLocations(const std::string &caFile, const std::string &caDir)
void testSslVerifyLocations(const std::string &caFile, const std::string &caDir, const std::string &error)
{
if (!caFile.empty() && getFileSize(caFile) <= 0)
throw ConfigFileException(formatString("SSL 'ca_file' file '%s' is empty or invalid", caFile.c_str()));
Expand All @@ -576,7 +576,7 @@ void testSslVerifyLocations(const std::string &caFile, const std::string &caDir)
if (SSL_CTX_load_verify_locations(sslCtx.get(), ca_file, ca_dir) != 1)
{
ERR_print_errors_cb(logSslError, NULL);
throw std::runtime_error("Loading ca_file/ca_dir failed.");
throw ConfigFileException(error);
}
}

Expand Down
2 changes: 1 addition & 1 deletion utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ std::string generateBadHttpRequestReponse(const std::string &msg);
std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol);

void testSsl(const std::string &fullchain, const std::string &privkey);
void testSslVerifyLocations(const std::string &caFile, const std::string &caDir);
void testSslVerifyLocations(const std::string &caFile, const std::string &caDir, const std::string &error);

std::string formatString(const std::string str, ...);

Expand Down
Loading

0 comments on commit de2aa39

Please sign in to comment.