diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 5a59f372ad3..9b631fb397a 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1780,8 +1780,6 @@ def test_networked_bad_cert(self): h.request('GET', '/') self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED') - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipIf(sys.platform == 'darwin', 'Occasionally success on macOS') def test_local_unknown_cert(self): # The custom cert isn't known to the default trust bundle diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index fd6d3e8a598..cc6f76a1a64 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -107,7 +107,7 @@ gethostname = "1.0.2" socket2 = { version = "0.6.0", features = ["all"] } dns-lookup = "3.0" openssl = { version = "0.10.72", optional = true } -openssl-sys = { version = "0.9.80", optional = true } +openssl-sys = { version = "0.9.110", optional = true } openssl-probe = { version = "0.1.5", optional = true } foreign-types-shared = { version = "0.1.1", optional = true } diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index c9a9e15f8e3..bf77d6b6907 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -26,7 +26,7 @@ cfg_if::cfg_if! { } #[allow(non_upper_case_globals)] -#[pymodule(with(ossl101, windows))] +#[pymodule(with(ossl101, ossl111, windows))] mod _ssl { use super::{bio, probe}; use crate::{ @@ -39,7 +39,7 @@ mod _ssl { socket::{self, PySocket}, vm::{ Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak}, + builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak}, class_or_notimplemented, convert::{ToPyException, ToPyObject}, exceptions, @@ -66,7 +66,8 @@ mod _ssl { ffi::CStr, fmt, io::{Read, Write}, - path::Path, + path::{Path, PathBuf}, + sync::LazyLock, time::Instant, }; @@ -84,19 +85,24 @@ mod _ssl { SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, - // #ifdef SSL_OP_SINGLE_ECDH_USE - // SSL_OP_SINGLE_ECDH_USE as OP_SINGLE_ECDH_USE - // #endif // X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF, // sys::X509_V_FLAG_CRL_CHECK|sys::X509_V_FLAG_CRL_CHECK_ALL as VERIFY_CRL_CHECK_CHAIN // X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, SSL_ERROR_ZERO_RETURN, SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, + SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT, + SSL_OP_LEGACY_SERVER_CONNECT as OP_LEGACY_SERVER_CONNECT, SSL_OP_NO_SSLv2 as OP_NO_SSLv2, SSL_OP_NO_SSLv3 as OP_NO_SSLv3, SSL_OP_NO_TICKET as OP_NO_TICKET, SSL_OP_NO_TLSv1 as OP_NO_TLSv1, SSL_OP_SINGLE_DH_USE as OP_SINGLE_DH_USE, + SSL_OP_SINGLE_ECDH_USE as OP_SINGLE_ECDH_USE, + X509_V_FLAG_ALLOW_PROXY_CERTS as VERIFY_ALLOW_PROXY_CERTS, + X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF, + X509_V_FLAG_PARTIAL_CHAIN as VERIFY_X509_PARTIAL_CHAIN, + X509_V_FLAG_TRUSTED_FIRST as VERIFY_X509_TRUSTED_FIRST, + X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, }; // taken from CPython, should probably be kept up to date with their version if it ever changes @@ -116,6 +122,10 @@ mod _ssl { #[pyattr] const PROTOCOL_TLSv1: u32 = SslVersion::Tls1 as u32; #[pyattr] + const PROTOCOL_TLSv1_1: u32 = SslVersion::Tls1_1 as u32; + #[pyattr] + const PROTOCOL_TLSv1_2: u32 = SslVersion::Tls1_2 as u32; + #[pyattr] const PROTO_MINIMUM_SUPPORTED: i32 = ProtoVersion::MinSupported as i32; #[pyattr] const PROTO_SSLv3: i32 = ProtoVersion::Ssl3 as i32; @@ -146,7 +156,7 @@ mod _ssl { #[pyattr] const HAS_SNI: bool = true; #[pyattr] - const HAS_ECDH: bool = false; + const HAS_ECDH: bool = true; #[pyattr] const HAS_NPN: bool = false; #[pyattr] @@ -183,7 +193,8 @@ mod _ssl { #[pyattr(name = "_OPENSSL_API_VERSION")] fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo { - let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16).unwrap(); + let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16) + .expect("OPENSSL_API_VERSION is malformed"); parse_version_info(openssl_api_version) } @@ -241,7 +252,8 @@ mod _ssl { /// SSL/TLS connection terminated abruptly. #[pyattr(name = "SSLEOFError", once)] fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef { - PyType::new_simple_heap("ssl.SSLEOFError", &ssl_error(vm), &vm.ctx).unwrap() + vm.ctx + .new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)])) } type OpensslVersionInfo = (u8, u8, u8, u8, u8); @@ -265,7 +277,8 @@ mod _ssl { Ssl3 = 1, Tls, Tls1, - // TODO: Tls1_1, Tls1_2 ? + Tls1_1, + Tls1_2, TlsClient = 0x10, TlsServer, } @@ -341,14 +354,17 @@ mod _ssl { } type PyNid = (libc::c_int, String, String, Option); - fn obj2py(obj: &Asn1ObjectRef) -> PyNid { + fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult { let nid = obj.nid(); - ( - nid.as_raw(), - nid.short_name().unwrap().to_owned(), - nid.long_name().unwrap().to_owned(), - obj2txt(obj, true), - ) + let short_name = nid + .short_name() + .map_err(|_| vm.new_value_error("NID has no short name".to_owned()))? + .to_owned(); + let long_name = nid + .long_name() + .map_err(|_| vm.new_value_error("NID has no long name".to_owned()))? + .to_owned(); + Ok((nid.as_raw(), short_name, long_name, obj2txt(obj, true))) } #[derive(FromArgs)] @@ -362,55 +378,81 @@ mod _ssl { fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { _txt2obj(&args.txt.to_cstring(vm)?, !args.name) .as_deref() - .map(obj2py) .ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt))) + .and_then(|obj| obj2py(obj, vm)) } #[pyfunction] fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { _nid2obj(Nid::from_raw(nid)) .as_deref() - .map(obj2py) .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}"))) + .and_then(|obj| obj2py(obj, vm)) } - fn get_cert_file_dir() -> (&'static Path, &'static Path) { - let probe = probe(); - // on windows, these should be utf8 strings - fn path_from_bytes(c: &CStr) -> &Path { + // Lazily compute and cache cert file/dir paths + static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| { + fn path_from_cstr(c: &CStr) -> PathBuf { #[cfg(unix)] { use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref() + std::ffi::OsStr::from_bytes(c.to_bytes()).into() } #[cfg(windows)] { - c.to_str().unwrap().as_ref() + // Use lossy conversion for potential non-UTF8 + PathBuf::from(c.to_string_lossy().as_ref()) } } - let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| { - path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) - }); - let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| { - path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) - }); + + let probe = probe(); + let cert_file = probe + .cert_file + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) + }); + let cert_dir = probe + .cert_dir + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) + }); (cert_file, cert_dir) + }); + + fn get_cert_file_dir() -> (&'static Path, &'static Path) { + let (cert_file, cert_dir) = &*CERT_PATHS; + (cert_file.as_path(), cert_dir.as_path()) } + // Lazily compute and cache cert environment variable names + static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| { + let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } + .to_string_lossy() + .into_owned(); + let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } + .to_string_lossy() + .into_owned(); + (cert_file_env, cert_dir_env) + }); + #[pyfunction] fn get_default_verify_paths( vm: &VirtualMachine, ) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> { - let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } - .to_str() - .unwrap(); - let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } - .to_str() - .unwrap(); + let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES; let (cert_file, cert_dir) = get_cert_file_dir(); let cert_file = OsPath::new_str(cert_file).filename(vm); let cert_dir = OsPath::new_str(cert_dir).filename(vm); - Ok((cert_file_env, cert_file, cert_dir_env, cert_dir)) + Ok(( + cert_file_env.as_str(), + cert_file, + cert_dir_env.as_str(), + cert_dir, + )) } #[pyfunction(name = "RAND_status")] @@ -480,7 +522,9 @@ mod _ssl { let method = match proto { // SslVersion::Ssl3 => unsafe { ssl::SslMethod::from_ptr(sys::SSLv3_method()) }, SslVersion::Tls => ssl::SslMethod::tls(), - // TODO: Tls1_1, Tls1_2 ? + SslVersion::Tls1 => ssl::SslMethod::tls(), + SslVersion::Tls1_1 => ssl::SslMethod::tls(), + SslVersion::Tls1_2 => ssl::SslMethod::tls(), SslVersion::TlsClient => ssl::SslMethod::tls_client(), SslVersion::TlsServer => ssl::SslMethod::tls_server(), _ => return Err(vm.new_value_error("invalid protocol version")), @@ -509,6 +553,7 @@ mod _ssl { options |= SslOptions::CIPHER_SERVER_PREFERENCE; options |= SslOptions::SINGLE_DH_USE; options |= SslOptions::SINGLE_ECDH_USE; + options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT; builder.set_options(options); let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; @@ -523,6 +568,13 @@ mod _ssl { .set_session_id_context(b"Python") .map_err(|e| convert_openssl_error(vm, e))?; + // Set default verify flags: VERIFY_X509_TRUSTED_FIRST + unsafe { + let ctx_ptr = builder.as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); + } + PySslContext { ctx: PyRwLock::new(builder), check_hostname: AtomicCell::new(check_hostname), @@ -569,6 +621,64 @@ mod _ssl { }) } + #[pymethod] + fn get_ciphers(&self, vm: &VirtualMachine) -> PyResult { + let ctx = self.ctx(); + let ssl = ssl::Ssl::new(&ctx).map_err(|e| convert_openssl_error(vm, e))?; + + unsafe { + let ciphers_ptr = SSL_get_ciphers(ssl.as_ptr()); + if ciphers_ptr.is_null() { + return Ok(vm.ctx.new_list(vec![])); + } + + let num_ciphers = sys::OPENSSL_sk_num(ciphers_ptr as *const _); + let mut result = Vec::new(); + + for i in 0..num_ciphers { + let cipher_ptr = + sys::OPENSSL_sk_value(ciphers_ptr as *const _, i) as *const sys::SSL_CIPHER; + let cipher = ssl::SslCipherRef::from_ptr(cipher_ptr as *mut _); + + let (name, version, bits) = cipher_to_tuple(cipher); + let dict = vm.ctx.new_dict(); + dict.set_item("name", vm.ctx.new_str(name).into(), vm)?; + dict.set_item("protocol", vm.ctx.new_str(version).into(), vm)?; + dict.set_item("secret_bits", vm.ctx.new_int(bits).into(), vm)?; + result.push(dict.into()); + } + + Ok(vm.ctx.new_list(result)) + } + } + + #[pymethod] + fn set_ecdh_curve(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { + use openssl::ec::{EcGroup, EcKey}; + + let curve_name = name.as_str(); + if curve_name.contains('\0') { + return Err(exceptions::cstring_error(vm)); + } + + // Find the NID for the curve name using OBJ_sn2nid + let name_cstr = name.to_cstring(vm)?; + let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) }; + if nid_raw == 0 { + return Err(vm.new_value_error(format!("unknown curve name: {}", curve_name))); + } + let nid = Nid::from_raw(nid_raw); + + // Create EC key from the curve + let group = EcGroup::from_curve_name(nid).map_err(|e| convert_openssl_error(vm, e))?; + let key = EcKey::from_group(&group).map_err(|e| convert_openssl_error(vm, e))?; + + // Set the temporary ECDH key + self.builder() + .set_tmp_ecdh(&key) + .map_err(|e| convert_openssl_error(vm, e)) + } + #[pygetset] fn options(&self) -> libc::c_ulong { self.ctx.read().options().bits() as _ @@ -616,6 +726,38 @@ mod _ssl { Ok(()) } #[pygetset] + fn verify_flags(&self) -> libc::c_ulong { + unsafe { + let ctx_ptr = self.ctx().as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + sys::X509_VERIFY_PARAM_get_flags(param) + } + } + #[pygetset(setter)] + fn set_verify_flags(&self, new_flags: libc::c_ulong, vm: &VirtualMachine) -> PyResult<()> { + unsafe { + let ctx_ptr = self.ctx().as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + let flags = sys::X509_VERIFY_PARAM_get_flags(param); + let clear = flags & !new_flags; + let set = !flags & new_flags; + + if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Failed to clear verify flags".to_owned(), + )); + } + if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Failed to set verify flags".to_owned(), + )); + } + Ok(()) + } + } + #[pygetset] fn check_hostname(&self) -> bool { self.check_hostname.load() } @@ -743,8 +885,16 @@ mod _ssl { let certs = ctx.cert_store().all_certificates(); #[cfg(not(ossl300))] let certs = ctx.cert_store().objects().iter().filter_map(|x| x.x509()); + + // Filter to only include CA certificates (Basic Constraints: CA=TRUE) let certs = certs .into_iter() + .filter(|cert| { + unsafe { + // X509_check_ca() returns 1 for CA certificates + X509_check_ca(cert.as_ptr()) == 1 + } + }) .map(|ref cert| cert_to_py(vm, cert, binary_form)) .collect::, _>>()?; Ok(certs) @@ -781,6 +931,20 @@ mod _ssl { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult { + // validate socket type and context protocol + if !args.server_side && zelf.protocol == SslVersion::TlsServer { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + )); + } + if args.server_side && zelf.protocol == SslVersion::TlsClient { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + )); + } + let mut ssl = ssl::Ssl::new(&zelf.ctx()).map_err(|e| convert_openssl_error(vm, e))?; let socket_type = if args.server_side { @@ -1041,6 +1205,37 @@ mod _ssl { Some(vm.ctx.new_list(certs).into()) } + #[pymethod] + fn get_verified_chain(&self, vm: &VirtualMachine) -> Option { + let stream = self.stream.read(); + unsafe { + let chain = sys::SSL_get0_verified_chain(stream.ssl().as_ptr()); + if chain.is_null() { + return None; + } + + let num_certs = sys::OPENSSL_sk_num(chain as *const _); + let mut certs = Vec::new(); + + for i in 0..num_certs { + let cert_ptr = sys::OPENSSL_sk_value(chain as *const _, i) as *mut sys::X509; + if cert_ptr.is_null() { + continue; + } + let cert = X509Ref::from_ptr(cert_ptr); + if let Ok(der) = cert.to_der() { + certs.push(vm.ctx.new_bytes(der).into()); + } + } + + if certs.is_empty() { + None + } else { + Some(vm.ctx.new_list(certs)) + } + } + } + #[pymethod] fn version(&self) -> Option<&'static str> { let v = self.stream.read().ssl().version_str(); @@ -1056,6 +1251,189 @@ mod _ssl { .map(cipher_to_tuple) } + #[pymethod] + fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { + #[cfg(ossl110)] + { + let stream = self.stream.read(); + unsafe { + let server_ciphers = SSL_get_ciphers(stream.ssl().as_ptr()); + if server_ciphers.is_null() { + return None; + } + + let client_ciphers = SSL_get_client_ciphers(stream.ssl().as_ptr()); + if client_ciphers.is_null() { + return None; + } + + let mut result = Vec::new(); + let num_server = sys::OPENSSL_sk_num(server_ciphers as *const _); + let num_client = sys::OPENSSL_sk_num(client_ciphers as *const _); + + for i in 0..num_server { + let server_cipher_ptr = sys::OPENSSL_sk_value(server_ciphers as *const _, i) + as *const sys::SSL_CIPHER; + + // Check if client supports this cipher by comparing pointers + let mut found = false; + for j in 0..num_client { + let client_cipher_ptr = + sys::OPENSSL_sk_value(client_ciphers as *const _, j) + as *const sys::SSL_CIPHER; + + if server_cipher_ptr == client_cipher_ptr { + found = true; + break; + } + } + + if found { + let cipher = ssl::SslCipherRef::from_ptr(server_cipher_ptr as *mut _); + let (name, version, bits) = cipher_to_tuple(cipher); + let tuple = vm.new_tuple(( + vm.ctx.new_str(name), + vm.ctx.new_str(version), + vm.ctx.new_int(bits), + )); + result.push(tuple.into()); + } + } + + if result.is_empty() { + None + } else { + Some(vm.ctx.new_list(result)) + } + } + } + #[cfg(not(ossl110))] + { + let _ = vm; + None + } + } + + #[pymethod] + fn selected_alpn_protocol(&self) -> Option { + #[cfg(ossl102)] + { + let stream = self.stream.read(); + unsafe { + let mut out: *const libc::c_uchar = std::ptr::null(); + let mut outlen: libc::c_uint = 0; + + sys::SSL_get0_alpn_selected(stream.ssl().as_ptr(), &mut out, &mut outlen); + + if out.is_null() { + None + } else { + let slice = std::slice::from_raw_parts(out, outlen as usize); + Some(String::from_utf8_lossy(slice).into_owned()) + } + } + } + #[cfg(not(ossl102))] + { + None + } + } + + #[pymethod] + fn get_channel_binding( + &self, + cb_type: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + const CB_MAXLEN: usize = 512; + + let cb_type_str = cb_type.as_ref().map_or("tls-unique", |s| s.as_str()); + + if cb_type_str != "tls-unique" { + return Err(vm.new_value_error(format!( + "Unsupported channel binding type '{}'", + cb_type_str + ))); + } + + let stream = self.stream.read(); + let ssl_ptr = stream.ssl().as_ptr(); + + unsafe { + let session_reused = sys::SSL_session_reused(ssl_ptr) != 0; + let is_client = matches!(self.socket_type, SslServerOrClient::Client); + + // Use XOR logic from CPython + let use_finished = session_reused ^ is_client; + + let mut buf = vec![0u8; CB_MAXLEN]; + let len = if use_finished { + sys::SSL_get_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) + } else { + sys::SSL_get_peer_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) + }; + + if len == 0 { + Ok(None) + } else { + buf.truncate(len); + Ok(Some(vm.ctx.new_bytes(buf))) + } + } + } + + #[pymethod] + fn verify_client_post_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + #[cfg(ossl111)] + { + let stream = self.stream.read(); + let result = unsafe { SSL_verify_client_post_handshake(stream.ssl().as_ptr()) }; + if result == 0 { + Err(vm.new_exception_msg( + ssl_error(vm), + "Post-handshake authentication failed".to_owned(), + )) + } else { + Ok(()) + } + } + #[cfg(not(ossl111))] + { + Err(vm.new_not_implemented_error( + "Post-handshake auth is not supported by your OpenSSL version.".to_owned(), + )) + } + } + + #[pymethod] + fn shutdown(&self, vm: &VirtualMachine) -> PyResult> { + let stream = self.stream.read(); + let ssl_ptr = stream.ssl().as_ptr(); + + // Perform SSL shutdown + let ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + + if ret < 0 { + // Error occurred + let err = unsafe { sys::SSL_get_error(ssl_ptr, ret) }; + + if err == sys::SSL_ERROR_WANT_READ || err == sys::SSL_ERROR_WANT_WRITE { + // Non-blocking would block - this is okay for shutdown + // Return the underlying socket + } else { + return Err(vm.new_exception_msg( + ssl_error(vm), + format!("SSL shutdown failed: error code {}", err), + )); + } + } + + // Return the underlying socket + // Get the socket from the stream (SocketStream wraps PyRef) + let socket = stream.get_ref(); + Ok(socket.0.clone()) + } + #[cfg(osslconf = "OPENSSL_NO_COMP")] #[pymethod] fn compression(&self) -> Option<&'static str> { @@ -1267,7 +1645,7 @@ mod _ssl { let ret = match inner_buffer { Either::A(_buf) => vm.ctx.new_int(count).into(), Either::B(mut buf) => { - buf.truncate(n); + buf.truncate(count); buf.shrink_to_fit(); vm.ctx.new_bytes(buf).into() } @@ -1363,6 +1741,27 @@ mod _ssl { unsafe impl Send for PySslMemoryBio {} unsafe impl Sync for PySslMemoryBio {} + // OpenSSL functions not in openssl-sys + + unsafe extern "C" { + // X509_check_ca returns 1 for CA certificates, 0 otherwise + fn X509_check_ca(x: *const sys::X509) -> libc::c_int; + } + + unsafe extern "C" { + fn SSL_get_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; + } + + #[cfg(ossl110)] + unsafe extern "C" { + fn SSL_get_client_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; + } + + #[cfg(ossl111)] + unsafe extern "C" { + fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int; + } + // OpenSSL BIO helper functions // These are typically macros in OpenSSL, implemented via BIO_ctrl const BIO_CTRL_PENDING: libc::c_int = 10; @@ -1525,12 +1924,12 @@ mod _ssl { } #[pygetset] - fn id(&self, vm: &VirtualMachine) -> PyObjectRef { + fn id(&self, vm: &VirtualMachine) -> PyBytesRef { unsafe { let mut len: libc::c_uint = 0; let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len); let id_slice = std::slice::from_raw_parts(id_ptr, len as usize); - vm.ctx.new_bytes(id_slice.to_vec()).into() + vm.ctx.new_bytes(id_slice.to_vec()) } } @@ -1568,23 +1967,39 @@ mod _ssl { "certificate verify failed" => "CERTIFICATE_VERIFY_FAILED", _ => default_errstr, }; - let msg = if let Some(lib) = e.library() { - // add `library` attribute - let attr_name = vm.ctx.as_ref().intern_str("library"); - cls.set_attr(attr_name, vm.ctx.new_str(lib).into()); + + // Build message + let lib_obj = e.library(); + let msg = if let Some(lib) = lib_obj { format!("[{lib}] {errstr} ({file}:{line})") } else { format!("{errstr} ({file}:{line})") }; - // add `reason` attribute - let attr_name = vm.ctx.as_ref().intern_str("reason"); - cls.set_attr(attr_name, vm.ctx.new_str(errstr).into()); + // Create exception instance let reason = sys::ERR_GET_REASON(e.code()); - vm.new_exception( + let exc = vm.new_exception( cls, vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], - ) + ); + + // Set attributes on instance, not class + let exc_obj: PyObjectRef = exc.into(); + + // Set reason attribute (always set, even if just the error string) + let reason_value = vm.ctx.new_str(errstr); + let _ = exc_obj.set_attr("reason", reason_value, vm); + + // Set library attribute (None if not available) + let library_value: PyObjectRef = if let Some(lib) = lib_obj { + vm.ctx.new_str(lib).into() + } else { + vm.ctx.none() + }; + let _ = exc_obj.set_attr("library", library_value, vm); + + // Convert back to PyBaseExceptionRef + exc_obj.downcast().unwrap() } None => vm.new_exception_empty(cls), } @@ -1681,7 +2096,8 @@ mod _ssl { dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?; dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?; - dict.set_item("version", vm.new_pyobj(cert.version()), vm)?; + // X.509 version: OpenSSL uses 0-based (0=v1, 1=v2, 2=v3) but Python uses 1-based (1=v1, 2=v2, 3=v3) + dict.set_item("version", vm.new_pyobj(cert.version() + 1), vm)?; let serial_num = cert .serial_number() @@ -1894,21 +2310,18 @@ mod windows { Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), other => vm.new_pyobj(other), }; - let usage: PyObjectRef = match c.valid_uses()? { + let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? { ValidUses::All => vm.ctx.new_bool(true).into(), ValidUses::Oids(oids) => PyFrozenSet::from_iter( vm, oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), - ) - .unwrap() + )? .into_ref(&vm.ctx) .into(), }; Ok(vm.new_tuple((cert, enc_type, usage)).into()) }); - let certs = certs - .collect::, _>>() - .map_err(|e: std::io::Error| e.to_pyexception(vm))?; + let certs: Vec = certs.collect::>>()?; Ok(certs) } }