12
12
#include < lwip/sockets.h>
13
13
#include < lwip/sys.h>
14
14
#include < lwip/netdb.h>
15
+ #include < mbedtls/sha256.h>
16
+ #include < mbedtls/oid.h>
17
+ #include < algorithm>
18
+ #include < string>
15
19
#include " ssl_client.h"
16
20
17
21
@@ -262,3 +266,145 @@ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length)
262
266
// log_v( "%d bytes read", ret); //for low level debug
263
267
return ret;
264
268
}
269
+
270
+ static bool parseHexNibble (char pb, uint8_t * res)
271
+ {
272
+ if (pb >= ' 0' && pb <= ' 9' ) {
273
+ *res = (uint8_t ) (pb - ' 0' ); return true ;
274
+ } else if (pb >= ' a' && pb <= ' f' ) {
275
+ *res = (uint8_t ) (pb - ' a' + 10 ); return true ;
276
+ } else if (pb >= ' A' && pb <= ' F' ) {
277
+ *res = (uint8_t ) (pb - ' A' + 10 ); return true ;
278
+ }
279
+ return false ;
280
+ }
281
+
282
+ // Compare a name from certificate and domain name, return true if they match
283
+ static bool matchName (const std::string& name, const std::string& domainName)
284
+ {
285
+ size_t wildcardPos = name.find (' *' );
286
+ if (wildcardPos == std::string::npos) {
287
+ // Not a wildcard, expect an exact match
288
+ return name == domainName;
289
+ }
290
+
291
+ size_t firstDotPos = name.find (' .' );
292
+ if (wildcardPos > firstDotPos) {
293
+ // Wildcard is not part of leftmost component of domain name
294
+ // Do not attempt to match (rfc6125 6.4.3.1)
295
+ return false ;
296
+ }
297
+ if (wildcardPos != 0 || firstDotPos != 1 ) {
298
+ // Matching of wildcards such as baz*.example.com and b*z.example.com
299
+ // is optional. Maybe implement this in the future?
300
+ return false ;
301
+ }
302
+ size_t domainNameFirstDotPos = domainName.find (' .' );
303
+ if (domainNameFirstDotPos == std::string::npos) {
304
+ return false ;
305
+ }
306
+ return domainName.substr (domainNameFirstDotPos) == name.substr (firstDotPos);
307
+ }
308
+
309
+ // Verifies certificate provided by the peer to match specified SHA256 fingerprint
310
+ bool verify_ssl_fingerprint (sslclient_context *ssl_client, const char * fp, const char * domain_name)
311
+ {
312
+ // Convert hex string to byte array
313
+ uint8_t fingerprint_local[32 ];
314
+ int len = strlen (fp);
315
+ int pos = 0 ;
316
+ for (size_t i = 0 ; i < sizeof (fingerprint_local); ++i) {
317
+ while (pos < len && ((fp[pos] == ' ' ) || (fp[pos] == ' :' ))) {
318
+ ++pos;
319
+ }
320
+ if (pos > len - 2 ) {
321
+ log_d (" pos:%d len:%d fingerprint too short" , pos, len);
322
+ return false ;
323
+ }
324
+ uint8_t high, low;
325
+ if (!parseHexNibble (fp[pos], &high) || !parseHexNibble (fp[pos+1 ], &low)) {
326
+ log_d (" pos:%d len:%d invalid hex sequence: %c%c" , pos, len, fp[pos], fp[pos+1 ]);
327
+ return false ;
328
+ }
329
+ pos += 2 ;
330
+ fingerprint_local[i] = low | (high << 4 );
331
+ }
332
+
333
+ // Get certificate provided by the peer
334
+ const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert (&ssl_client->ssl_ctx );
335
+
336
+ if (!crt)
337
+ {
338
+ log_d (" could not fetch peer certificate" );
339
+ return false ;
340
+ }
341
+
342
+ // Calculate certificate's SHA256 fingerprint
343
+ uint8_t fingerprint_remote[32 ];
344
+ mbedtls_sha256_context sha256_ctx;
345
+ mbedtls_sha256_init (&sha256_ctx);
346
+ mbedtls_sha256_starts (&sha256_ctx, false );
347
+ mbedtls_sha256_update (&sha256_ctx, crt->raw .p , crt->raw .len );
348
+ mbedtls_sha256_finish (&sha256_ctx, fingerprint_remote);
349
+
350
+ // Check if fingerprints match
351
+ if (memcmp (fingerprint_local, fingerprint_remote, 32 ))
352
+ {
353
+ log_d (" fingerprint doesn't match" );
354
+ return false ;
355
+ }
356
+
357
+ // Additionally check if certificate has domain name if provided
358
+ if (domain_name)
359
+ return verify_ssl_dn (ssl_client, domain_name);
360
+ else
361
+ return true ;
362
+ }
363
+
364
+ // Checks if peer certificate has specified domain in CN or SANs
365
+ bool verify_ssl_dn (sslclient_context *ssl_client, const char * domain_name)
366
+ {
367
+ log_d (" domain name: '%s'" , (domain_name)?domain_name:" (null)" );
368
+ std::string domain_name_str (domain_name);
369
+ std::transform (domain_name_str.begin (), domain_name_str.end (), domain_name_str.begin (), ::tolower);
370
+
371
+ // Get certificate provided by the peer
372
+ const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert (&ssl_client->ssl_ctx );
373
+
374
+ // Check for domain name in SANs
375
+ const mbedtls_x509_sequence* san = &crt->subject_alt_names ;
376
+ while (san != nullptr )
377
+ {
378
+ std::string san_str ((const char *)san->buf .p , san->buf .len );
379
+ std::transform (san_str.begin (), san_str.end (), san_str.begin (), ::tolower);
380
+
381
+ if (matchName (san_str, domain_name_str))
382
+ return true ;
383
+
384
+ log_d (" SAN '%s': no match" , san_str.c_str ());
385
+
386
+ // Fetch next SAN
387
+ san = san->next ;
388
+ }
389
+
390
+ // Check for domain name in CN
391
+ const mbedtls_asn1_named_data* common_name = &crt->subject ;
392
+ while (common_name != nullptr )
393
+ {
394
+ // While iterating through DN objects, check for CN object
395
+ if (!MBEDTLS_OID_CMP (MBEDTLS_OID_AT_CN, &common_name->oid ))
396
+ {
397
+ std::string common_name_str ((const char *)common_name->val .p , common_name->val .len );
398
+
399
+ if (matchName (common_name_str, domain_name_str))
400
+ return true ;
401
+
402
+ log_d (" CN '%s': not match" , common_name_str.c_str ());
403
+ }
404
+
405
+ // Fetch next DN object
406
+ common_name = common_name->next ;
407
+ }
408
+
409
+ return false ;
410
+ }
0 commit comments