From 71afd0a03b7f8b6a6aaed2b49f8807cda68b4d13 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Mon, 10 Mar 2025 00:32:51 -0700 Subject: [PATCH] rewrite of winreg module and add test_winreg --- Cargo.lock | 80 ++- Lib/test/test_winreg.py | 552 +++++++++++++++++++++ vm/Cargo.toml | 1 - vm/src/stdlib/winreg.rs | 1044 ++++++++++++++++++++++++++++++--------- 4 files changed, 1401 insertions(+), 276 deletions(-) create mode 100644 Lib/test/test_winreg.py diff --git a/Cargo.lock b/Cargo.lock index 0255535c1b..610a2ead77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -115,9 +115,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "approx" @@ -213,9 +213,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -283,9 +283,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.18" +version = "1.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" +checksum = "8e3a13707ac958681c13b39b458c073d0d9bc8a22cb1b2f4c8e55eb72c13f362" dependencies = [ "shlex", ] @@ -365,18 +365,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2df961d8c8a0d08aa9945718ccf584145eee3f3aa06cddbeac12933781102e04" +checksum = "eccb054f56cbd38340b380d4a8e69ef1f02f1af43db2f0cc817a4774d80ae071" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "132dbda40fb6753878316a489d5a1242a8ef2f0d9e47ba01c951ea8aa7d013a5" +checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" dependencies = [ "anstyle", "clap_lex", @@ -1037,9 +1037,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -1219,9 +1219,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jiff" -version = "0.2.5" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c102670231191d07d37a35af3eb77f1f0dbf7a71be51a962dcd57ea607be7260" +checksum = "59ec30f7142be6fe14e1b021f50b85db8df2d4324ea6e91ec3e5dcde092021d0" dependencies = [ "jiff-static", "log", @@ -1232,9 +1232,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.5" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cdde31a9d349f1b1f51a0b3714a5940ac022976f4b49485fc04be052b183b4c" +checksum = "526b834d727fd59d37b076b0c3236d9adde1b1729a4361e20b2026f738cc1dbe" dependencies = [ "proc-macro2", "quote", @@ -1272,9 +1272,9 @@ dependencies = [ [[package]] name = "lambert_w" -version = "1.2.9" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd4d9b9fa6582f5d77f954729c91c32a7c85834332e470b014d12e1678fd1793" +checksum = "913e1e36ca541d75f384593fa70bf5a5e9f001f2996bfa7926550d921f83baf6" dependencies = [ "num-complex", "num-traits", @@ -1330,9 +1330,9 @@ checksum = "0864a00c8d019e36216b69c2c4ce50b83b7bd966add3cf5ba554ec44f8bebcf5" [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libffi" @@ -1407,9 +1407,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "lock_api" @@ -1557,9 +1557,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -1708,9 +1708,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-src" -version = "300.4.2+3.4.1" +version = "300.5.0+3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168ce4e058f975fe43e89d9ccf78ca668601887ae736090aacc23ae353c298e2" +checksum = "e8ce546f549326b0e6052b649198487d91320875da901e7bd11a06d1ee3f9c2f" dependencies = [ "cc", ] @@ -1912,9 +1912,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -2039,13 +2039,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.24", ] [[package]] @@ -2368,7 +2367,7 @@ dependencies = [ name = "rustpython-compiler" version = "0.4.0" dependencies = [ - "rand 0.9.0", + "rand 0.9.1", "ruff_python_ast", "ruff_python_parser", "ruff_source_file", @@ -2457,7 +2456,7 @@ dependencies = [ "is-macro", "lexical-parse-float", "num-traits", - "rand 0.9.0", + "rand 0.9.1", "rustpython-wtf8", "unic-ucd-category", ] @@ -2633,7 +2632,6 @@ dependencies = [ "widestring", "windows", "windows-sys 0.59.0", - "winreg", ] [[package]] @@ -3684,16 +3682,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "winreg" -version = "0.55.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb5a765337c50e9ec252c2069be9bf91c7df47afb103b642ba3a53bf8101be97" -dependencies = [ - "cfg-if", - "windows-sys 0.59.0", -] - [[package]] name = "winsafe" version = "0.0.19" @@ -3711,9 +3699,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" +checksum = "a62ce76d9b56901b19a74f19431b0d8b3bc7ca4ad685a746dfd78ca8f4fc6bda" [[package]] name = "zerocopy" diff --git a/Lib/test/test_winreg.py b/Lib/test/test_winreg.py new file mode 100644 index 0000000000..2c530ee754 --- /dev/null +++ b/Lib/test/test_winreg.py @@ -0,0 +1,552 @@ +# Test the windows specific win32reg module. +# Only win32reg functions not hit here: FlushKey, LoadKey and SaveKey + +import gc +import os, sys, errno +import threading +import unittest +from platform import machine, win32_edition +from test.support import cpython_only, import_helper + +# Do this first so test will be skipped if module doesn't exist +import_helper.import_module('winreg', required_on=['win']) +# Now import everything +from winreg import * + +try: + REMOTE_NAME = sys.argv[sys.argv.index("--remote")+1] +except (IndexError, ValueError): + REMOTE_NAME = None + +# tuple of (major, minor) +WIN_VER = sys.getwindowsversion()[:2] +# Some tests should only run on 64-bit architectures where WOW64 will be. +WIN64_MACHINE = True if machine() == "AMD64" else False + +# Starting with Windows 7 and Windows Server 2008 R2, WOW64 no longer uses +# registry reflection and formerly reflected keys are shared instead. +# Windows 7 and Windows Server 2008 R2 are version 6.1. Due to this, some +# tests are only valid up until 6.1 +HAS_REFLECTION = True if WIN_VER < (6, 1) else False + +# Use a per-process key to prevent concurrent test runs (buildbot!) from +# stomping on each other. +test_key_base = "Python Test Key [%d] - Delete Me" % (os.getpid(),) +test_key_name = "SOFTWARE\\" + test_key_base +# On OS'es that support reflection we should test with a reflected key +test_reflect_key_name = "SOFTWARE\\Classes\\" + test_key_base + +test_data = [ + ("Int Value", 45, REG_DWORD), + ("Qword Value", 0x1122334455667788, REG_QWORD), + ("String Val", "A string value", REG_SZ), + ("StringExpand", "The path is %path%", REG_EXPAND_SZ), + ("Multi-string", ["Lots", "of", "string", "values"], REG_MULTI_SZ), + ("Multi-nul", ["", "", "", ""], REG_MULTI_SZ), + ("Raw Data", b"binary\x00data", REG_BINARY), + ("Big String", "x"*(2**14-1), REG_SZ), + ("Big Binary", b"x"*(2**14), REG_BINARY), + # Two and three kanjis, meaning: "Japan" and "Japanese". + ("Japanese 日本", "日本語", REG_SZ), +] + + +@cpython_only +class HeapTypeTests(unittest.TestCase): + def test_have_gc(self): + self.assertTrue(gc.is_tracked(HKEYType)) + + def test_immutable(self): + with self.assertRaisesRegex(TypeError, "immutable"): + HKEYType.foo = "bar" + + +class BaseWinregTests(unittest.TestCase): + + def setUp(self): + # Make sure that the test key is absent when the test + # starts. + self.delete_tree(HKEY_CURRENT_USER, test_key_name) + + def delete_tree(self, root, subkey): + try: + hkey = OpenKey(root, subkey, 0, KEY_ALL_ACCESS) + except OSError: + # subkey does not exist + return + while True: + try: + subsubkey = EnumKey(hkey, 0) + except OSError: + # no more subkeys + break + self.delete_tree(hkey, subsubkey) + CloseKey(hkey) + DeleteKey(root, subkey) + + def _write_test_data(self, root_key, subkeystr="sub_key", + CreateKey=CreateKey): + # Set the default value for this key. + SetValue(root_key, test_key_name, REG_SZ, "Default value") + key = CreateKey(root_key, test_key_name) + self.assertTrue(key.handle != 0) + # Create a sub-key + sub_key = CreateKey(key, subkeystr) + # Give the sub-key some named values + + for value_name, value_data, value_type in test_data: + SetValueEx(sub_key, value_name, 0, value_type, value_data) + + # Check we wrote as many items as we thought. + nkeys, nvalues, since_mod = QueryInfoKey(key) + self.assertEqual(nkeys, 1, "Not the correct number of sub keys") + self.assertEqual(nvalues, 1, "Not the correct number of values") + nkeys, nvalues, since_mod = QueryInfoKey(sub_key) + self.assertEqual(nkeys, 0, "Not the correct number of sub keys") + self.assertEqual(nvalues, len(test_data), + "Not the correct number of values") + # Close this key this way... + # (but before we do, copy the key as an integer - this allows + # us to test that the key really gets closed). + int_sub_key = int(sub_key) + CloseKey(sub_key) + try: + QueryInfoKey(int_sub_key) + self.fail("It appears the CloseKey() function does " + "not close the actual key!") + except OSError: + pass + # ... and close that key that way :-) + int_key = int(key) + key.Close() + try: + QueryInfoKey(int_key) + self.fail("It appears the key.Close() function " + "does not close the actual key!") + except OSError: + pass + def _read_test_data(self, root_key, subkeystr="sub_key", OpenKey=OpenKey): + # Check we can get default value for this key. + val = QueryValue(root_key, test_key_name) + self.assertEqual(val, "Default value", + "Registry didn't give back the correct value") + + key = OpenKey(root_key, test_key_name) + # Read the sub-keys + with OpenKey(key, subkeystr) as sub_key: + # Check I can enumerate over the values. + index = 0 + while 1: + try: + data = EnumValue(sub_key, index) + except OSError: + break + self.assertEqual(data in test_data, True, + "Didn't read back the correct test data") + index = index + 1 + self.assertEqual(index, len(test_data), + "Didn't read the correct number of items") + # Check I can directly access each item + for value_name, value_data, value_type in test_data: + read_val, read_typ = QueryValueEx(sub_key, value_name) + self.assertEqual(read_val, value_data, + "Could not directly read the value") + self.assertEqual(read_typ, value_type, + "Could not directly read the value") + sub_key.Close() + # Enumerate our main key. + read_val = EnumKey(key, 0) + self.assertEqual(read_val, subkeystr, "Read subkey value wrong") + try: + EnumKey(key, 1) + self.fail("Was able to get a second key when I only have one!") + except OSError: + pass + + key.Close() + + def _delete_test_data(self, root_key, subkeystr="sub_key"): + key = OpenKey(root_key, test_key_name, 0, KEY_ALL_ACCESS) + sub_key = OpenKey(key, subkeystr, 0, KEY_ALL_ACCESS) + # It is not necessary to delete the values before deleting + # the key (although subkeys must not exist). We delete them + # manually just to prove we can :-) + for value_name, value_data, value_type in test_data: + DeleteValue(sub_key, value_name) + + nkeys, nvalues, since_mod = QueryInfoKey(sub_key) + self.assertEqual(nkeys, 0, "subkey not empty before delete") + self.assertEqual(nvalues, 0, "subkey not empty before delete") + sub_key.Close() + DeleteKey(key, subkeystr) + + try: + # Shouldn't be able to delete it twice! + DeleteKey(key, subkeystr) + self.fail("Deleting the key twice succeeded") + except OSError: + pass + key.Close() + DeleteKey(root_key, test_key_name) + # Opening should now fail! + try: + key = OpenKey(root_key, test_key_name) + self.fail("Could open the non-existent key") + except OSError: # Use this error name this time + pass + + def _test_all(self, root_key, subkeystr="sub_key"): + self._write_test_data(root_key, subkeystr) + self._read_test_data(root_key, subkeystr) + self._delete_test_data(root_key, subkeystr) + + def _test_named_args(self, key, sub_key): + with CreateKeyEx(key=key, sub_key=sub_key, reserved=0, + access=KEY_ALL_ACCESS) as ckey: + self.assertTrue(ckey.handle != 0) + + with OpenKeyEx(key=key, sub_key=sub_key, reserved=0, + access=KEY_ALL_ACCESS) as okey: + self.assertTrue(okey.handle != 0) + + +class LocalWinregTests(BaseWinregTests): + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_registry_works(self): + self._test_all(HKEY_CURRENT_USER) + self._test_all(HKEY_CURRENT_USER, "日本-subkey") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_registry_works_extended_functions(self): + # Substitute the regular CreateKey and OpenKey calls with their + # extended counterparts. + # Note: DeleteKeyEx is not used here because it is platform dependent + cke = lambda key, sub_key: CreateKeyEx(key, sub_key, 0, KEY_ALL_ACCESS) + self._write_test_data(HKEY_CURRENT_USER, CreateKey=cke) + + oke = lambda key, sub_key: OpenKeyEx(key, sub_key, 0, KEY_READ) + self._read_test_data(HKEY_CURRENT_USER, OpenKey=oke) + + self._delete_test_data(HKEY_CURRENT_USER) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_named_arguments(self): + self._test_named_args(HKEY_CURRENT_USER, test_key_name) + # Use the regular DeleteKey to clean up + # DeleteKeyEx takes named args and is tested separately + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_connect_registry_to_local_machine_works(self): + # perform minimal ConnectRegistry test which just invokes it + h = ConnectRegistry(None, HKEY_LOCAL_MACHINE) + self.assertNotEqual(h.handle, 0) + h.Close() + self.assertEqual(h.handle, 0) + + def test_nonexistent_remote_registry(self): + connect = lambda: ConnectRegistry("abcdefghijkl", HKEY_CURRENT_USER) + self.assertRaises(OSError, connect) + + # TODO: RUSTPYTHON + @unittest.skip("flaky") + def testExpandEnvironmentStrings(self): + r = ExpandEnvironmentStrings("%windir%\\test") + self.assertEqual(type(r), str) + self.assertEqual(r, os.environ["windir"] + "\\test") + + def test_context_manager(self): + # ensure that the handle is closed if an exception occurs + try: + with ConnectRegistry(None, HKEY_LOCAL_MACHINE) as h: + self.assertNotEqual(h.handle, 0) + raise OSError + except OSError: + self.assertEqual(h.handle, 0) + + def test_changing_value(self): + # Issue2810: A race condition in 2.6 and 3.1 may cause + # EnumValue or QueryValue to raise "WindowsError: More data is + # available" + done = False + + class VeryActiveThread(threading.Thread): + def run(self): + with CreateKey(HKEY_CURRENT_USER, test_key_name) as key: + use_short = True + long_string = 'x'*2000 + while not done: + s = 'x' if use_short else long_string + use_short = not use_short + SetValue(key, 'changing_value', REG_SZ, s) + + thread = VeryActiveThread() + thread.start() + try: + with CreateKey(HKEY_CURRENT_USER, + test_key_name+'\\changing_value') as key: + for _ in range(1000): + num_subkeys, num_values, t = QueryInfoKey(key) + for i in range(num_values): + name = EnumValue(key, i) + QueryValue(key, name[0]) + finally: + done = True + thread.join() + DeleteKey(HKEY_CURRENT_USER, test_key_name+'\\changing_value') + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_long_key(self): + # Issue2810, in 2.6 and 3.1 when the key name was exactly 256 + # characters, EnumKey raised "WindowsError: More data is + # available" + name = 'x'*256 + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as key: + SetValue(key, name, REG_SZ, 'x') + num_subkeys, num_values, t = QueryInfoKey(key) + EnumKey(key, 0) + finally: + DeleteKey(HKEY_CURRENT_USER, '\\'.join((test_key_name, name))) + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_dynamic_key(self): + # Issue2810, when the value is dynamically generated, these + # raise "WindowsError: More data is available" in 2.6 and 3.1 + try: + EnumValue(HKEY_PERFORMANCE_DATA, 0) + except OSError as e: + if e.errno in (errno.EPERM, errno.EACCES): + self.skipTest("access denied to registry key " + "(are you running in a non-interactive session?)") + raise + QueryValueEx(HKEY_PERFORMANCE_DATA, "") + + # Reflection requires XP x64/Vista at a minimum. XP doesn't have this stuff + # or DeleteKeyEx so make sure their use raises NotImplementedError + @unittest.skipUnless(WIN_VER < (5, 2), "Requires Windows XP") + def test_reflection_unsupported(self): + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + + key = OpenKey(HKEY_CURRENT_USER, test_key_name) + self.assertNotEqual(key.handle, 0) + + with self.assertRaises(NotImplementedError): + DisableReflectionKey(key) + with self.assertRaises(NotImplementedError): + EnableReflectionKey(key) + with self.assertRaises(NotImplementedError): + QueryReflectionKey(key) + with self.assertRaises(NotImplementedError): + DeleteKeyEx(HKEY_CURRENT_USER, test_key_name) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_setvalueex_value_range(self): + # Test for Issue #14420, accept proper ranges for SetValueEx. + # Py2Reg, which gets called by SetValueEx, was using PyLong_AsLong, + # thus raising OverflowError. The implementation now uses + # PyLong_AsUnsignedLong to match DWORD's size. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + SetValueEx(ck, "test_name", None, REG_DWORD, 0x80000000) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_setvalueex_negative_one_check(self): + # Test for Issue #43984, check -1 was not set by SetValueEx. + # Py2Reg, which gets called by SetValueEx, wasn't checking return + # value by PyLong_AsUnsignedLong, thus setting -1 as value in the registry. + # The implementation now checks PyLong_AsUnsignedLong return value to assure + # the value set was not -1. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + with self.assertRaises(OverflowError): + SetValueEx(ck, "test_name_dword", None, REG_DWORD, -1) + SetValueEx(ck, "test_name_qword", None, REG_QWORD, -1) + self.assertRaises(FileNotFoundError, QueryValueEx, ck, "test_name_dword") + self.assertRaises(FileNotFoundError, QueryValueEx, ck, "test_name_qword") + + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_queryvalueex_return_value(self): + # Test for Issue #16759, return unsigned int from QueryValueEx. + # Reg2Py, which gets called by QueryValueEx, was returning a value + # generated by PyLong_FromLong. The implementation now uses + # PyLong_FromUnsignedLong to match DWORD's size. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = 0x80000000 + SetValueEx(ck, "test_name", None, REG_DWORD, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_DWORD) + self.assertEqual(ret_val, test_val) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_setvalueex_crash_with_none_arg(self): + # Test for Issue #21151, segfault when None is passed to SetValueEx + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = None + SetValueEx(ck, "test_name", 0, REG_BINARY, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_BINARY) + self.assertEqual(ret_val, test_val) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_read_string_containing_null(self): + # Test for issue 25778: REG_SZ should not contain null characters + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = "A string\x00 with a null" + SetValueEx(ck, "test_name", 0, REG_SZ, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_SZ) + self.assertEqual(ret_val, "A string") + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + +@unittest.skipUnless(REMOTE_NAME, "Skipping remote registry tests") +class RemoteWinregTests(BaseWinregTests): + + def test_remote_registry_works(self): + remote_key = ConnectRegistry(REMOTE_NAME, HKEY_CURRENT_USER) + self._test_all(remote_key) + + +@unittest.skipUnless(WIN64_MACHINE, "x64 specific registry tests") +class Win64WinregTests(BaseWinregTests): + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_named_arguments(self): + self._test_named_args(HKEY_CURRENT_USER, test_key_name) + # Clean up and also exercise the named arguments + DeleteKeyEx(key=HKEY_CURRENT_USER, sub_key=test_key_name, + access=KEY_ALL_ACCESS, reserved=0) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @unittest.skipIf(win32_edition() in ('WindowsCoreHeadless', 'IoTEdgeOS'), "APIs not available on WindowsCoreHeadless") + def test_reflection_functions(self): + # Test that we can call the query, enable, and disable functions + # on a key which isn't on the reflection list with no consequences. + with OpenKey(HKEY_LOCAL_MACHINE, "Software") as key: + # HKLM\Software is redirected but not reflected in all OSes + self.assertTrue(QueryReflectionKey(key)) + self.assertIsNone(EnableReflectionKey(key)) + self.assertIsNone(DisableReflectionKey(key)) + self.assertTrue(QueryReflectionKey(key)) + + @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection") + def test_reflection(self): + # Test that we can create, open, and delete keys in the 32-bit + # area. Because we are doing this in a key which gets reflected, + # test the differences of 32 and 64-bit keys before and after the + # reflection occurs (ie. when the created key is closed). + try: + with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key: + self.assertNotEqual(created_key.handle, 0) + + # The key should now be available in the 32-bit area + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as key: + self.assertNotEqual(key.handle, 0) + + # Write a value to what currently is only in the 32-bit area + SetValueEx(created_key, "", 0, REG_SZ, "32KEY") + + # The key is not reflected until created_key is closed. + # The 64-bit version of the key should not be available yet. + open_fail = lambda: OpenKey(HKEY_CURRENT_USER, + test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_64KEY) + self.assertRaises(OSError, open_fail) + + # Now explicitly open the 64-bit version of the key + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_64KEY) as key: + self.assertNotEqual(key.handle, 0) + # Make sure the original value we set is there + self.assertEqual("32KEY", QueryValue(key, "")) + # Set a new value, which will get reflected to 32-bit + SetValueEx(key, "", 0, REG_SZ, "64KEY") + + # Reflection uses a "last-writer wins policy, so the value we set + # on the 64-bit key should be the same on 32-bit + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_32KEY) as key: + self.assertEqual("64KEY", QueryValue(key, "")) + finally: + DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, + KEY_WOW64_32KEY, 0) + + @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection") + def test_disable_reflection(self): + # Make use of a key which gets redirected and reflected + try: + with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key: + # QueryReflectionKey returns whether or not the key is disabled + disabled = QueryReflectionKey(created_key) + self.assertEqual(type(disabled), bool) + # HKCU\Software\Classes is reflected by default + self.assertFalse(disabled) + + DisableReflectionKey(created_key) + self.assertTrue(QueryReflectionKey(created_key)) + + # The key is now closed and would normally be reflected to the + # 64-bit area, but let's make sure that didn't happen. + open_fail = lambda: OpenKeyEx(HKEY_CURRENT_USER, + test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_64KEY) + self.assertRaises(OSError, open_fail) + + # Make sure the 32-bit key is actually there + with OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_32KEY) as key: + self.assertNotEqual(key.handle, 0) + finally: + DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, + KEY_WOW64_32KEY, 0) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_exception_numbers(self): + with self.assertRaises(FileNotFoundError) as ctx: + QueryValue(HKEY_CLASSES_ROOT, 'some_value_that_does_not_exist') + + +if __name__ == "__main__": + if not REMOTE_NAME: + print("Remote registry calls can be tested using", + "'test_winreg.py --remote \\\\machine_name'") + unittest.main() diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 5a4b0df2a1..7125fe4a5f 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -117,7 +117,6 @@ num_cpus = "1.13.1" [target.'cfg(windows)'.dependencies] junction = { workspace = true } schannel = { workspace = true } -winreg = "0.55" [target.'cfg(windows)'.dependencies.windows] version = "0.52.0" diff --git a/vm/src/stdlib/winreg.rs b/vm/src/stdlib/winreg.rs index 8d1ca89ddd..39a269fd71 100644 --- a/vm/src/stdlib/winreg.rs +++ b/vm/src/stdlib/winreg.rs @@ -4,39 +4,31 @@ use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { - let module = winreg::make_module(vm); - - macro_rules! add_constants { - ($($name:ident),*$(,)?) => { - extend_module!(vm, &module, { - $((stringify!($name)) => vm.new_pyobj(::winreg::enums::$name as usize)),* - }) - }; - } - - add_constants!( - HKEY_CLASSES_ROOT, - HKEY_CURRENT_USER, - HKEY_LOCAL_MACHINE, - HKEY_USERS, - HKEY_PERFORMANCE_DATA, - HKEY_CURRENT_CONFIG, - HKEY_DYN_DATA, - ); - module + winreg::make_module(vm) } #[pymodule] mod winreg { - use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; - use crate::{ - PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyStrRef, - convert::ToPyException, - }; - use ::winreg::{RegKey, RegValue, enums::RegType}; - use std::mem::ManuallyDrop; - use std::{ffi::OsStr, io}; - use windows_sys::Win32::Foundation; + use std::ffi::OsStr; + use std::os::windows::ffi::OsStrExt; + use std::ptr; + use std::sync::Arc; + + use crate::builtins::{PyInt, PyTuple}; + use crate::common::lock::PyRwLock; + use crate::function::FuncArgs; + use crate::protocol::PyNumberMethods; + use crate::types::AsNumber; + use crate::{PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; + + use windows_sys::Win32::Foundation::{self, ERROR_MORE_DATA}; + use windows_sys::Win32::System::Registry; + + use num_traits::ToPrimitive; + + pub(crate) fn to_utf16>(s: P) -> Vec { + s.as_ref().encode_wide().chain(Some(0)).collect() + } // access rights #[pyattr] @@ -57,290 +49,884 @@ mod winreg { REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, REG_WHOLE_HIVE_VOLATILE, }; - #[pyattr] - #[pyclass(module = "winreg", name = "HKEYType")] - #[derive(Debug, PyPayload)] - struct PyHkey { - key: PyRwLock, + #[pyattr(once)] + fn HKEY_CLASSES_ROOT(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_CLASSES_ROOT)), + } } - type PyHkeyRef = PyRef; - // TODO: fix this - unsafe impl Sync for PyHkey {} + #[pyattr(once)] + fn HKEY_CURRENT_USER(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_CURRENT_USER)), + } + } - impl PyHkey { - fn new(key: RegKey) -> Self { - Self { - key: PyRwLock::new(key), - } + #[pyattr(once)] + fn HKEY_LOCAL_MACHINE(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_LOCAL_MACHINE)), } + } - fn key(&self) -> PyRwLockReadGuard<'_, RegKey> { - self.key.read() + #[pyattr(once)] + fn HKEY_USERS(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_USERS)), } + } - fn key_mut(&self) -> PyRwLockWriteGuard<'_, RegKey> { - self.key.write() + #[pyattr(once)] + fn HKEY_PERFORMANCE_DATA(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_PERFORMANCE_DATA)), } } - #[pyclass] - impl PyHkey { - #[pymethod] - fn Close(&self) { - let null_key = RegKey::predef(0 as ::winreg::HKEY); - let key = std::mem::replace(&mut *self.key_mut(), null_key); - drop(key); + #[pyattr(once)] + fn HKEY_CURRENT_CONFIG(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_CURRENT_CONFIG)), } - #[pymethod] - fn Detach(&self) -> usize { - let null_key = RegKey::predef(0 as ::winreg::HKEY); - let key = std::mem::replace(&mut *self.key_mut(), null_key); - let handle = key.raw_handle(); - std::mem::forget(key); - handle as usize + } + + #[pyattr(once)] + fn HKEY_DYN_DATA(_vm: &VirtualMachine) -> PyHKEYObject { + PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(Registry::HKEY_DYN_DATA)), + } + } + + #[pyattr] + #[pyclass(name)] + #[derive(Clone, Debug, PyPayload)] + pub struct PyHKEYObject { + hkey: Arc>, + } + + // TODO: Fix + unsafe impl Send for PyHKEYObject {} + unsafe impl Sync for PyHKEYObject {} + + #[pyclass(with(AsNumber))] + impl PyHKEYObject { + #[pygetset] + fn handle(&self) -> usize { + *self.hkey.read() as usize } #[pymethod(magic)] fn bool(&self) -> bool { - !self.key().raw_handle().is_null() + !self.hkey.read().is_null() } + #[pymethod(magic)] - fn enter(zelf: PyRef) -> PyRef { - zelf + fn int(&self) -> usize { + *self.hkey.read() as usize } + #[pymethod(magic)] - fn exit(&self, _cls: PyObjectRef, _exc: PyObjectRef, _tb: PyObjectRef) { - self.Close(); + fn str(&self) -> String { + format!("", *self.hkey.read() as usize) } - } - enum Hkey { - PyHkey(PyHkeyRef), - Constant(::winreg::HKEY), - } - impl TryFromObject for Hkey { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - obj.downcast().map(Self::PyHkey).or_else(|o| { - usize::try_from_object(vm, o).map(|i| Self::Constant(i as ::winreg::HKEY)) - }) + #[pymethod] + fn Close(&self, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegCloseKey(*self.hkey.write()) }; + *self.hkey.write() = std::ptr::null_mut(); + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error("msg TODO".to_string())) + } + } + + #[pymethod] + fn Detach(&self) -> PyResult { + let hkey = *self.hkey.write(); + // std::mem::forget(self); + // TODO: Fix this + Ok(hkey as usize) + } + + // fn AsHKEY(object: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + // if vm.is_none(object) { + // return Err(vm.new_type_error("cannot convert None to HKEY".to_owned())) + // } else if let Some(hkey) = object.downcast_ref::() { + // Ok(true) + // } else { + // Err(vm.new_type_error("The object is not a PyHKEY object".to_owned())) + // } + // } + + #[pymethod(magic)] + fn enter(zelf: PyRef, _vm: &VirtualMachine) -> PyResult> { + Ok(zelf) + } + + #[pymethod(magic)] + fn exit(zelf: PyRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegCloseKey(*zelf.hkey.write()) }; + *zelf.hkey.write() = std::ptr::null_mut(); + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error("msg TODO".to_string())) + } } } - impl Hkey { - fn with_key(&self, f: impl FnOnce(&RegKey) -> R) -> R { - match self { - Self::PyHkey(py) => f(&py.key()), - Self::Constant(hkey) => { - let k = ManuallyDrop::new(RegKey::predef(*hkey)); - f(&k) + + impl Drop for PyHKEYObject { + fn drop(&mut self) { + unsafe { + let hkey = *self.hkey.write(); + if !hkey.is_null() { + Registry::RegCloseKey(hkey); } } } - fn into_key(self) -> RegKey { - let k = match self { - Self::PyHkey(py) => py.key().raw_handle(), - Self::Constant(k) => k, + } + + pub const HKEY_ERR_MSG: &str = "bad operand type"; + + impl PyHKEYObject { + pub fn new(hkey: *mut std::ffi::c_void) -> Self { + Self { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(hkey)), + } + } + + pub fn unary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) + } + + pub fn binary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) + } + + pub fn ternary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) + } + } + + impl AsNumber for PyHKEYObject { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + subtract: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + multiply: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + remainder: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + divmod: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + power: Some(|_a, _b, _c, vm| PyHKEYObject::ternary_fail(vm)), + negative: Some(|_a, vm| PyHKEYObject::unary_fail(vm)), + positive: Some(|_a, vm| PyHKEYObject::unary_fail(vm)), + absolute: Some(|_a, vm| PyHKEYObject::unary_fail(vm)), + boolean: Some(|a, vm| { + if let Some(a) = a.downcast_ref::() { + Ok(a.bool()) + } else { + PyHKEYObject::unary_fail(vm)?; + unreachable!() + } + }), + invert: Some(|_a, vm| PyHKEYObject::unary_fail(vm)), + lshift: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + rshift: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + and: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + xor: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + or: Some(|_a, _b, vm| PyHKEYObject::binary_fail(vm)), + int: Some(|a, vm| { + if let Some(a) = a.downcast_ref::() { + Ok(vm.new_pyobj(a.int())) + } else { + PyHKEYObject::unary_fail(vm)?; + unreachable!() + } + }), + float: Some(|_a, vm| PyHKEYObject::unary_fail(vm)), + ..PyNumberMethods::NOT_IMPLEMENTED }; - RegKey::predef(k) + &AS_NUMBER } } - #[derive(FromArgs)] - struct OpenKeyArgs { - key: Hkey, - sub_key: Option, + // TODO: Computer name can be `None` + #[pyfunction] + fn ConnectRegistry( + computer_name: Option, + key: PyRef, + vm: &VirtualMachine, + ) -> PyResult { + if let Some(computer_name) = computer_name { + let mut ret_key = std::ptr::null_mut(); + let wide_computer_name = to_utf16(computer_name); + let res = unsafe { + Registry::RegConnectRegistryW( + wide_computer_name.as_ptr(), + *key.hkey.read(), + &mut ret_key, + ) + }; + if res == 0 { + Ok(PyHKEYObject::new(ret_key)) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } else { + let mut ret_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegConnectRegistryW(std::ptr::null_mut(), *key.hkey.read(), &mut ret_key) + }; + if res == 0 { + Ok(PyHKEYObject::new(ret_key)) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + } + + #[pyfunction] + fn CreateKey( + key: PyRef, + sub_key: String, + vm: &VirtualMachine, + ) -> PyResult { + let wide_sub_key = to_utf16(sub_key); + let mut out_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegCreateKeyW(*key.hkey.read(), wide_sub_key.as_ptr(), &mut out_key) + }; + if res == 0 { + Ok(PyHKEYObject::new(out_key)) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + + #[derive(FromArgs, Debug)] + struct CreateKeyExArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, #[pyarg(any, default = 0)] - reserved: i32, - #[pyarg(any, default = ::winreg::enums::KEY_READ)] + reserved: u32, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WRITE)] access: u32, } - #[pyfunction(name = "OpenKeyEx")] #[pyfunction] - fn OpenKey(args: OpenKeyArgs, vm: &VirtualMachine) -> PyResult { - let OpenKeyArgs { - key, - sub_key, - reserved, - access, - } = args; + fn CreateKeyEx(args: CreateKeyExArgs, vm: &VirtualMachine) -> PyResult { + let wide_sub_key = to_utf16(args.sub_key); + let mut res: *mut std::ffi::c_void = core::ptr::null_mut(); + let err = unsafe { + let key = *args.key.hkey.read(); + Registry::RegCreateKeyExW( + key, + wide_sub_key.as_ptr(), + args.reserved, + std::ptr::null(), + Registry::REG_OPTION_NON_VOLATILE, + args.access, + std::ptr::null(), + &mut res, + std::ptr::null_mut(), + ) + }; + if err == 0 { + Ok(PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(res)), + }) + } else { + Err(vm.new_os_error(format!("error code: {}", err))) + } + } + + #[pyfunction] + fn DeleteKey(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult<()> { + let wide_sub_key = to_utf16(sub_key); + let res = unsafe { Registry::RegDeleteKeyW(*key.hkey.read(), wide_sub_key.as_ptr()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + + #[derive(FromArgs, Debug)] + struct DeleteKeyExArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, + #[pyarg(any, default = 0)] + reserved: u32, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WOW64_32KEY)] + access: u32, + } + + #[pyfunction] + fn DeleteKeyEx(args: DeleteKeyExArgs, vm: &VirtualMachine) -> PyResult<()> { + let wide_sub_key = to_utf16(args.sub_key); + let res = unsafe { + Registry::RegDeleteKeyExW( + *args.key.hkey.read(), + wide_sub_key.as_ptr(), + args.reserved, + args.access, + ) + }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + + // #[pyfunction] + // fn EnumKey(key: PyRef, index: i32, vm: &VirtualMachine) -> PyResult { + // let mut tmpbuf = [0u16; 257]; + // let mut len = std::mem::sizeof(tmpbuf.len())/std::mem::sizeof(tmpbuf[0]); + // let res = unsafe { + // Registry::RegEnumKeyExW( + // *key.hkey.read(), + // index as u32, + // tmpbuf.as_mut_ptr(), + // &mut len, + // std::ptr::null_mut(), + // std::ptr::null_mut(), + // std::ptr::null_mut(), + // std::ptr::null_mut(), + // ) + // }; + // if res != 0 { + // return Err(vm.new_os_error(format!("error code: {}", res))); + // } + // let s = String::from_utf16(&tmpbuf[..len as usize]) + // .map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?; + // Ok(s) + // } - if reserved != 0 { - // RegKey::open_subkey* doesn't have a reserved param, so this'll do - return Err(vm.new_value_error("reserved param must be 0".to_owned())); + #[pyfunction] + fn EnumValue(hkey: PyRef, index: u32, vm: &VirtualMachine) -> PyResult { + // Query registry for the required buffer sizes. + let mut ret_value_size: u32 = 0; + let mut ret_data_size: u32 = 0; + let hkey: *mut std::ffi::c_void = *hkey.hkey.read(); + let rc = unsafe { + Registry::RegQueryInfoKeyW( + hkey, + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + &mut ret_value_size as *mut u32, + &mut ret_data_size as *mut u32, + ptr::null_mut(), + ptr::null_mut(), + ) + }; + if rc != 0 { + return Err(vm.new_os_error(format!("RegQueryInfoKeyW failed with error code {}", rc))); } - let sub_key = sub_key.as_ref().map_or("", |s| s.as_str()); - let key = key - .with_key(|k| k.open_subkey_with_flags(sub_key, access)) - .map_err(|e| e.to_pyexception(vm))?; + // Include room for null terminators. + ret_value_size += 1; + ret_data_size += 1; + let mut buf_value_size = ret_value_size; + let mut buf_data_size = ret_data_size; + + // Allocate buffers. + let mut ret_value_buf: Vec = vec![0; ret_value_size as usize]; + let mut ret_data_buf: Vec = vec![0; ret_data_size as usize]; + + // Loop to enumerate the registry value. + loop { + let mut current_value_size = ret_value_size; + let mut current_data_size = ret_data_size; + let rc = unsafe { + Registry::RegEnumValueW( + hkey, + index, + ret_value_buf.as_mut_ptr(), + &mut current_value_size as *mut u32, + ptr::null_mut(), + { + // typ will hold the registry data type. + let mut t = 0u32; + &mut t + }, + ret_data_buf.as_mut_ptr(), + &mut current_data_size as *mut u32, + ) + }; + if rc == ERROR_MORE_DATA { + // Double the buffer sizes. + buf_data_size *= 2; + buf_value_size *= 2; + ret_data_buf.resize(buf_data_size as usize, 0); + ret_value_buf.resize(buf_value_size as usize, 0); + // Reset sizes for next iteration. + ret_value_size = buf_value_size; + ret_data_size = buf_data_size; + continue; + } + if rc != 0 { + return Err(vm.new_os_error(format!("RegEnumValueW failed with error code {}", rc))); + } - Ok(PyHkey::new(key)) + // At this point, current_value_size and current_data_size have been updated. + // Retrieve the registry type. + let mut reg_type: u32 = 0; + unsafe { + Registry::RegEnumValueW( + hkey, + index, + ret_value_buf.as_mut_ptr(), + &mut current_value_size as *mut u32, + ptr::null_mut(), + &mut reg_type as *mut u32, + ret_data_buf.as_mut_ptr(), + &mut current_data_size as *mut u32, + ) + }; + + // Convert the registry value name from UTF‑16. + let name_len = ret_value_buf + .iter() + .position(|&c| c == 0) + .unwrap_or(ret_value_buf.len()); + let name = String::from_utf16(&ret_value_buf[..name_len]) + .map_err(|e| vm.new_value_error(format!("UTF16 conversion error: {}", e)))?; + + // Slice the data buffer to the actual size returned. + let data_slice = &ret_data_buf[..current_data_size as usize]; + let py_data = reg_to_py(vm, data_slice, reg_type)?; + + // Return tuple (value_name, data, type) + return Ok(vm + .ctx + .new_tuple(vec![ + vm.ctx.new_str(name).into(), + py_data, + vm.ctx.new_int(reg_type).into(), + ]) + .into()); + } } #[pyfunction] - fn QueryValue(key: Hkey, subkey: Option, vm: &VirtualMachine) -> PyResult { - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - key.with_key(|k| k.get_value(subkey)) - .map_err(|e| e.to_pyexception(vm)) + fn FlushKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegFlushKey(*key.hkey.read()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } } #[pyfunction] - fn QueryValueEx( - key: Hkey, - subkey: Option, + fn LoadKey( + key: PyRef, + sub_key: String, + file_name: String, vm: &VirtualMachine, - ) -> PyResult<(PyObjectRef, usize)> { - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - let regval = key - .with_key(|k| k.get_raw_value(subkey)) - .map_err(|e| e.to_pyexception(vm))?; - #[allow(clippy::redundant_clone)] - let ty = regval.vtype.clone() as usize; - Ok((reg_to_py(regval, vm)?, ty)) + ) -> PyResult<()> { + let sub_key = to_utf16(sub_key); + let file_name = to_utf16(file_name); + let res = unsafe { + Registry::RegLoadKeyW(*key.hkey.read(), sub_key.as_ptr(), file_name.as_ptr()) + }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + + #[derive(Debug, FromArgs)] + struct OpenKeyArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, + #[pyarg(any, default = 0)] + reserved: u32, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_READ)] + access: u32, } #[pyfunction] - fn EnumKey(key: Hkey, index: u32, vm: &VirtualMachine) -> PyResult { - key.with_key(|k| k.enum_keys().nth(index as usize)) - .unwrap_or_else(|| { - Err(io::Error::from_raw_os_error( - Foundation::ERROR_NO_MORE_ITEMS as i32, - )) - }) - .map_err(|e| e.to_pyexception(vm)) + #[pyfunction(name = "OpenKeyEx")] + fn OpenKey(args: OpenKeyArgs, vm: &VirtualMachine) -> PyResult { + let wide_sub_key = to_utf16(args.sub_key); + let res: *mut *mut std::ffi::c_void = core::ptr::null_mut(); + let err = unsafe { + let key = *args.key.hkey.read(); + Registry::RegOpenKeyExW(key, wide_sub_key.as_ptr(), args.reserved, args.access, res) + }; + if err == 0 { + unsafe { + Ok(PyHKEYObject { + #[allow(clippy::arc_with_non_send_sync)] + hkey: Arc::new(PyRwLock::new(*res)), + }) + } + } else { + Err(vm.new_os_error(format!("error code: {}", err))) + } } #[pyfunction] - fn EnumValue( - key: Hkey, - index: u32, - vm: &VirtualMachine, - ) -> PyResult<(String, PyObjectRef, usize)> { - let (name, value) = key - .with_key(|k| k.enum_values().nth(index as usize)) - .unwrap_or_else(|| { - Err(io::Error::from_raw_os_error( - Foundation::ERROR_NO_MORE_ITEMS as i32, - )) - }) - .map_err(|e| e.to_pyexception(vm))?; - #[allow(clippy::redundant_clone)] - let ty = value.vtype.clone() as usize; - Ok((name, reg_to_py(value, vm)?, ty)) + fn QueryInfoKey(key: PyRef, vm: &VirtualMachine) -> PyResult> { + let key = *key.hkey.read(); + let mut lpcsubkeys: u32 = 0; + let mut lpcvalues: u32 = 0; + let mut lpftlastwritetime: Foundation::FILETIME = unsafe { std::mem::zeroed() }; + let err = unsafe { + Registry::RegQueryInfoKeyW( + key, + std::ptr::null_mut(), + std::ptr::null_mut(), + 0 as _, + &mut lpcsubkeys, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut lpcvalues, + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut lpftlastwritetime, + ) + }; + + if err != 0 { + return Err(vm.new_os_error(format!("error code: {}", err))); + } + let l: u64 = (lpftlastwritetime.dwHighDateTime as u64) << 32 + | lpftlastwritetime.dwLowDateTime as u64; + let tup: Vec = vec![ + vm.ctx.new_int(lpcsubkeys).into(), + vm.ctx.new_int(lpcvalues).into(), + vm.ctx.new_int(l).into(), + ]; + Ok(vm.ctx.new_tuple(tup)) } #[pyfunction] - fn CloseKey(key: Hkey) { - match key { - Hkey::PyHkey(py) => py.Close(), - Hkey::Constant(hkey) => drop(RegKey::predef(hkey)), + fn QueryValue(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult<()> { + let key = *key.hkey.read(); + let mut lpcbdata: i32 = 0; + // let mut lpdata = 0; + let wide_sub_key = to_utf16(sub_key); + let err = unsafe { + Registry::RegQueryValueW( + key, + wide_sub_key.as_ptr(), + std::ptr::null_mut(), + &mut lpcbdata, + ) + }; + + if err != 0 { + return Err(vm.new_os_error(format!("error code: {}", err))); } + + Ok(()) } #[pyfunction] - fn CreateKey(key: Hkey, subkey: Option, vm: &VirtualMachine) -> PyResult { - let k = match subkey { - Some(subkey) => { - let (k, _disp) = key - .with_key(|k| k.create_subkey(subkey.as_str())) - .map_err(|e| e.to_pyexception(vm))?; - k - } - None => key.into_key(), + fn QueryValueEx( + key: PyRef, + name: String, + vm: &VirtualMachine, + ) -> PyResult { + let wide_name = to_utf16(name); + let mut buf_size = 0; + let res = unsafe { + Registry::RegQueryValueExW( + *key.hkey.read(), + wide_name.as_ptr(), + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut buf_size, + ) }; - Ok(PyHkey::new(k)) + // TODO: res == ERROR_MORE_DATA + if res != 0 { + return Err(vm.new_os_error(format!("error code: {}", res))); + } + let mut retBuf = Vec::with_capacity(buf_size as usize); + let mut typ = 0; + let res = unsafe { + Registry::RegQueryValueExW( + *key.hkey.read(), + wide_name.as_ptr(), + std::ptr::null_mut(), + &mut typ, + retBuf.as_mut_ptr(), + &mut buf_size, + ) + }; + // TODO: res == ERROR_MORE_DATA + if res != 0 { + return Err(vm.new_os_error(format!("error code: {}", res))); + } + let obj = reg_to_py(vm, retBuf.as_slice(), typ)?; + Ok(obj) + } + + #[pyfunction] + fn SaveKey(key: PyRef, file_name: String, vm: &VirtualMachine) -> PyResult<()> { + let file_name = to_utf16(file_name); + let res = unsafe { + Registry::RegSaveKeyW(*key.hkey.read(), file_name.as_ptr(), std::ptr::null_mut()) + }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } } #[pyfunction] fn SetValue( - key: Hkey, - subkey: Option, + key: PyRef, + sub_key: String, typ: u32, - value: PyStrRef, + value: String, vm: &VirtualMachine, ) -> PyResult<()> { - if typ != REG_SZ { - return Err(vm.new_type_error("type must be winreg.REG_SZ".to_owned())); + if typ != Registry::REG_SZ { + return Err(vm.new_type_error("type must be winreg.REG_SZ".to_string())); } - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - key.with_key(|k| k.set_value(subkey, &OsStr::new(value.as_str()))) - .map_err(|e| e.to_pyexception(vm)) - } - #[pyfunction] - fn DeleteKey(key: Hkey, subkey: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { - key.with_key(|k| k.delete_subkey(subkey.as_str())) - .map_err(|e| e.to_pyexception(vm)) + let wide_sub_key = to_utf16(sub_key); + + // TODO: Value check + if *key.hkey.read() == Registry::HKEY_PERFORMANCE_DATA { + return Err(vm.new_os_error("Cannot set value on HKEY_PERFORMANCE_DATA".to_string())); + } + + // if (sub_key && sub_key[0]) { + // // TODO: create key + // } + + let res = unsafe { + Registry::RegSetValueExW( + *key.hkey.read(), + wide_sub_key.as_ptr(), + 0, + typ, + value.as_ptr(), + value.len() as u32, + ) + }; + + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } } - fn reg_to_py(value: RegValue, vm: &VirtualMachine) -> PyResult { - macro_rules! bytes_to_int { - ($int:ident, $f:ident, $name:ident) => {{ - let i = if value.bytes.is_empty() { - Ok(0 as $int) + fn reg_to_py(vm: &VirtualMachine, ret_data: &[u8], typ: u32) -> PyResult { + match typ { + REG_DWORD => { + // If there isn’t enough data, return 0. + if ret_data.len() < std::mem::size_of::() { + Ok(vm.ctx.new_int(0).into()) } else { - (&*value.bytes).try_into().map($int::$f).map_err(|_| { - vm.new_value_error(format!("{} value is wrong length", stringify!(name))) - }) - }; - i.map(|i| vm.ctx.new_int(i).into()) - }}; - } - let bytes_to_wide = |b| { - if <[u8]>::len(b) % 2 == 0 { - let (pref, wide, suf) = unsafe { <[u8]>::align_to::(b) }; - assert!(pref.is_empty() && suf.is_empty(), "wide slice is unaligned"); - Some(wide) - } else { - None + let val = u32::from_ne_bytes(ret_data[..4].try_into().unwrap()); + Ok(vm.ctx.new_int(val).into()) + } } - }; - match value.vtype { - RegType::REG_DWORD => bytes_to_int!(u32, from_ne_bytes, REG_DWORD), - RegType::REG_DWORD_BIG_ENDIAN => { - bytes_to_int!(u32, from_be_bytes, REG_DWORD_BIG_ENDIAN) + REG_QWORD => { + if ret_data.len() < std::mem::size_of::() { + Ok(vm.ctx.new_int(0).into()) + } else { + let val = u64::from_ne_bytes(ret_data[..8].try_into().unwrap()); + Ok(vm.ctx.new_int(val).into()) + } } - RegType::REG_QWORD => bytes_to_int!(u64, from_ne_bytes, REG_DWORD), - // RegType::REG_QWORD_BIG_ENDIAN => bytes_to_int!(u64, from_be_bytes, REG_DWORD_BIG_ENDIAN), - RegType::REG_SZ | RegType::REG_EXPAND_SZ => { - let wide_slice = bytes_to_wide(&value.bytes).ok_or_else(|| { - vm.new_value_error("REG_SZ string doesn't have an even byte length".to_owned()) - })?; - let nul_pos = wide_slice + REG_SZ | REG_EXPAND_SZ => { + // Treat the data as a UTF-16 string. + let u16_count = ret_data.len() / 2; + let u16_slice = unsafe { + std::slice::from_raw_parts(ret_data.as_ptr() as *const u16, u16_count) + }; + // Only use characters up to the first NUL. + let len = u16_slice .iter() - .position(|w| *w == 0) - .unwrap_or(wide_slice.len()); - let s = String::from_utf16_lossy(&wide_slice[..nul_pos]); + .position(|&c| c == 0) + .unwrap_or(u16_slice.len()); + let s = String::from_utf16(&u16_slice[..len]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?; Ok(vm.ctx.new_str(s).into()) } - RegType::REG_MULTI_SZ => { - if value.bytes.is_empty() { - return Ok(vm.ctx.new_list(vec![]).into()); - } - let wide_slice = bytes_to_wide(&value.bytes).ok_or_else(|| { - vm.new_value_error( - "REG_MULTI_SZ string doesn't have an even byte length".to_owned(), - ) - })?; - let wide_slice = if let Some((0, rest)) = wide_slice.split_last() { - rest + REG_MULTI_SZ => { + if ret_data.is_empty() { + Ok(vm.ctx.new_list(vec![]).into()) } else { - wide_slice - }; - let strings = wide_slice - .split(|c| *c == 0) - .map(|s| vm.new_pyobj(String::from_utf16_lossy(s))) - .collect(); - Ok(vm.ctx.new_list(strings).into()) + let u16_count = ret_data.len() / 2; + let u16_slice = unsafe { + std::slice::from_raw_parts(ret_data.as_ptr() as *const u16, u16_count) + }; + let mut strings: Vec = Vec::new(); + let mut start = 0; + for (i, &c) in u16_slice.iter().enumerate() { + if c == 0 { + // An empty string signals the end. + if start == i { + break; + } + let s = String::from_utf16(&u16_slice[start..i]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?; + strings.push(vm.ctx.new_str(s).into()); + start = i + 1; + } + } + Ok(vm.ctx.new_list(strings).into()) + } } + // For REG_BINARY and any other unknown types, return a bytes object if data exists. _ => { - if value.bytes.is_empty() { + if ret_data.is_empty() { Ok(vm.ctx.none()) } else { - Ok(vm.ctx.new_bytes(value.bytes).into()) + Ok(vm.ctx.new_bytes(ret_data.to_vec()).into()) } } } } + + fn py2reg(value: PyObjectRef, typ: u32, vm: &VirtualMachine) -> PyResult>> { + match typ { + REG_DWORD => { + let val = value.downcast_ref::(); + if val.is_none() { + return Err(vm.new_type_error("value must be an integer".to_string())); + } + let val = val.unwrap().as_bigint().to_u32().unwrap(); + Ok(Some(val.to_le_bytes().to_vec())) + } + REG_QWORD => { + let val = value.downcast_ref::(); + if val.is_none() { + return Err(vm.new_type_error("value must be an integer".to_string())); + } + let val = val.unwrap().as_bigint().to_u64().unwrap(); + Ok(Some(val.to_le_bytes().to_vec())) + } + // REG_SZ is fallthrough + REG_EXPAND_SZ => { + Err(vm + .new_type_error("TODO: RUSTPYTHON REG_EXPAND_SZ is not supported".to_string())) + } + REG_MULTI_SZ => { + Err(vm.new_type_error("TODO: RUSTPYTHON REG_MULTI_SZ is not supported".to_string())) + } + // REG_BINARY is fallthrough + _ => { + if vm.is_none(&value) { + return Ok(None); + } + Err(vm.new_type_error("TODO: RUSTPYTHON Not supported".to_string())) + } + } + } + + #[pyfunction] + fn SetValueEx( + key: PyRef, + value_name: String, + _reserved: u32, + typ: u32, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + match py2reg(value, typ, vm) { + Ok(Some(v)) => { + let len = v.len() as u32; + let ptr = v.as_ptr(); + let wide_value_name = to_utf16(value_name); + let res = unsafe { + Registry::RegSetValueExW( + *key.hkey.read(), + wide_value_name.as_ptr(), + 0, + typ, + ptr, + len, + ) + }; + if res != 0 { + return Err(vm.new_os_error(format!("error code: {}", res))); + } + } + Ok(None) => { + let len = 0; + let ptr = std::ptr::null(); + let wide_value_name = to_utf16(value_name); + let res = unsafe { + Registry::RegSetValueExW( + *key.hkey.read(), + wide_value_name.as_ptr(), + 0, + typ, + ptr, + len, + ) + }; + if res != 0 { + return Err(vm.new_os_error(format!("error code: {}", res))); + } + } + Err(_) => return Err(vm.new_type_error("value must be an integer".to_string())), + } + Ok(()) + } + + #[pyfunction] + fn EnableReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegEnableReflectionKey(*key.hkey.read()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {}", res))) + } + } + + #[pyfunction] + fn ExpandEnvironmentStrings(i: String) -> PyResult { + let mut out = vec![0; 1024]; + let r = unsafe { + windows_sys::Win32::System::Environment::ExpandEnvironmentStringsA( + i.as_ptr(), + out.as_mut_ptr(), + out.len() as u32, + ) + }; + let s = String::from_utf8(out[..r as usize].to_vec()) + .unwrap() + .replace("\0", "") + .replace("\x02", "") + .to_string(); + + Ok(s) + } }