diff --git a/tests/regress.c b/tests/regress.c index 5d069d7fc..003ef26fb 100644 --- a/tests/regress.c +++ b/tests/regress.c @@ -498,6 +498,16 @@ static int AcceptAnyServerHostKey(const byte* pubKey, word32 pubKeySz, return 0; } +static int RejectAnyServerHostKey(const byte* pubKey, word32 pubKeySz, + void* ctx) +{ + (void)pubKey; + (void)pubKeySz; + (void)ctx; + + return 1; +} + static int QueueAppend(DuplexQueue* queue, const byte* data, word32 dataSz) { if (queue == NULL || data == NULL) { @@ -938,6 +948,34 @@ static void TestKexDhReplyRejectsNoPublicKeyCheck(void) #endif } +static void AssertHandshakeRejectsWhenCallbackRejects(const char* keyAlgo) +{ + KexReplyHarness harness; + KexReplyRunResult result; + + InitKexReplyHarness(&harness, keyAlgo, 0); + wolfSSH_CTX_SetPublicKeyCheck(harness.clientCtx, RejectAnyServerHostKey); + RunKexReplyHandshake(&harness, &result); + + AssertFalse(result.clientSuccess); + AssertTrue(result.clientRet == WS_FATAL_ERROR); + AssertTrue(result.clientErr != WS_WANT_READ && result.clientErr != WS_WANT_WRITE); + AssertIntEQ(result.clientErr, WS_PUBKEY_REJECTED_E); + AssertFalse(harness.client->connectState >= CONNECT_KEYED); + + FreeKexReplyHarness(&harness); +} + +static void TestKexDhReplyRejectsWhenCallbackRejects(void) +{ +#ifndef WOLFSSH_NO_RSA_SHA2_256 + AssertHandshakeRejectsWhenCallbackRejects("rsa-sha2-256"); +#endif +#ifndef WOLFSSH_NO_RSA_SHA2_512 + AssertHandshakeRejectsWhenCallbackRejects("rsa-sha2-512"); +#endif +} + #endif /* KEXDH_REPLY_REGRESS_KEX_ALGO */ static void AssertChannelOpenFailResponse(const ChannelOpenHarness* harness, @@ -1727,6 +1765,7 @@ int main(int argc, char** argv) TestKexDhReplyRejectsRsaSha2_512SigNameDowngrade(); #endif TestKexDhReplyRejectsNoPublicKeyCheck(); + TestKexDhReplyRejectsWhenCallbackRejects(); #endif #ifdef WOLFSSH_SFTP