diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index 7841b99a5cd4c1..475814eeb93fd7 100644 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -287,8 +287,11 @@ def refactor_dir(self, dir_name, write=False, doctests_only=False): Python files are assumed to have a .py extension. Files and subdirectories starting with '.' are skipped. + + Returns a list of changed files. """ py_ext = os.extsep + "py" + changed = [] for dirpath, dirnames, filenames in os.walk(dir_name): self.log_debug("Descending into %s", dirpath) dirnames.sort() @@ -297,9 +300,11 @@ def refactor_dir(self, dir_name, write=False, doctests_only=False): if (not name.startswith(".") and os.path.splitext(name)[1] == py_ext): fullname = os.path.join(dirpath, name) - self.refactor_file(fullname, write, doctests_only) + if self.refactor_file(fullname, write, doctests_only): + changed.append(fullname) # Modify dirnames in-place to remove subdirs with leading dots dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] + return changed def _read_python_source(self, filename): """ @@ -318,27 +323,37 @@ def _read_python_source(self, filename): return f.read(), encoding def refactor_file(self, filename, write=False, doctests_only=False): - """Refactors a file.""" + """Refactors a file. + + Returns: + True: if the file was modified. + False: if the file was not modified. + None: if the file could not be read. + """ input, encoding = self._read_python_source(filename) if input is None: # Reading the file failed. - return + return None input += "\n" # Silence certain parse errors if doctests_only: self.log_debug("Refactoring doctests in %s", filename) output = self.refactor_docstring(input, filename) if self.write_unchanged_files or output != input: self.processed_file(output, filename, input, write, encoding) + return True else: self.log_debug("No doctest changes in %s", filename) + return False else: tree = self.refactor_string(input, filename) if self.write_unchanged_files or (tree and tree.was_changed): # The [:-1] is to take off the \n we added earlier self.processed_file(str(tree)[:-1], filename, write=write, encoding=encoding) + return True else: self.log_debug("No changes in %s", filename) + return False def refactor_string(self, data, name): """Refactor a given input string. diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py index 9e3b8fbb90b2f3..35b19658b97f2f 100644 --- a/Lib/lib2to3/tests/test_refactor.py +++ b/Lib/lib2to3/tests/test_refactor.py @@ -179,7 +179,7 @@ def print_output(self, old_text, new_text, filename, equal): def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS, options=None, mock_log_debug=None, - actually_write=True): + actually_write=True, expected_return=True): test_file = self.init_test_file(test_file) old_contents = self.read_file(test_file) rt = self.rt(fixers=fixers, options=options) @@ -191,9 +191,13 @@ def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS, if not actually_write: return - rt.refactor_file(test_file, True) + ret = rt.refactor_file(test_file, True) new_contents = self.read_file(test_file) - self.assertNotEqual(old_contents, new_contents) + write_unchanged = options and options.get( + "write_unchanged_files", False) + if expected_return and not write_unchanged: + self.assertNotEqual(old_contents, new_contents) + self.assertEqual(ret, expected_return) return new_contents def init_test_file(self, test_file): @@ -220,6 +224,14 @@ def test_refactor_file(self): test_file = os.path.join(FIXER_DIR, "parrot_example.py") self.check_file_refactoring(test_file, _DEFAULT_FIXERS) + def test_refactor_file_return_false(self): + test_file = os.path.join(FIXER_DIR, "parrot_example.py") + self.check_file_refactoring(test_file, fixers=(), expected_return=False) + + def test_refactor_file_return_true(self): + test_file = os.path.join(TEST_DATA_DIR, "different_encoding.py") + self.check_file_refactoring(test_file, expected_return=True) + def test_refactor_file_write_unchanged_file(self): test_file = os.path.join(FIXER_DIR, "parrot_example.py") debug_messages = [] @@ -241,9 +253,11 @@ def recording_log_debug(msg, *args): self.fail("%r not matched in %r" % (message_regex, debug_messages)) def test_refactor_dir(self): - def check(structure, expected): + files_to_refactor = {"stuff.py"} + def check(structure, expected, expected_return): def mock_refactor_file(self, f, *args): got.append(f) + return os.path.basename(f) in files_to_refactor save_func = refactor.RefactoringTool.refactor_file refactor.RefactoringTool.refactor_file = mock_refactor_file rt = self.rt() @@ -253,13 +267,17 @@ def mock_refactor_file(self, f, *args): os.mkdir(os.path.join(dir, "a_dir")) for fn in structure: open(os.path.join(dir, fn), "wb").close() - rt.refactor_dir(dir) + ret =rt.refactor_dir(dir) + # Just check the basenames since we are working with tempdirs + self.assertEqual( + list(map(os.path.basename, ret)), + expected_return) finally: refactor.RefactoringTool.refactor_file = save_func shutil.rmtree(dir) self.assertEqual(got, [os.path.join(dir, path) for path in expected]) - check([], []) + check([], [], []) tree = ["nothing", "hi.py", ".dumb", @@ -267,10 +285,11 @@ def mock_refactor_file(self, f, *args): "notpy.npy", "sappy"] expected = ["hi.py"] - check(tree, expected) + check(tree, expected, []) tree = ["hi.py", os.path.join("a_dir", "stuff.py")] - check(tree, tree) + # stuff.py returns True from refactor_file and so is in expected_return + check(tree, tree, ["stuff.py"]) def test_file_encoding(self): fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") diff --git a/Misc/NEWS.d/next/Library/2018-11-20-14-31-36.bpo-35282.uimIBv.rst b/Misc/NEWS.d/next/Library/2018-11-20-14-31-36.bpo-35282.uimIBv.rst new file mode 100644 index 00000000000000..4888112cb76fe7 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-11-20-14-31-36.bpo-35282.uimIBv.rst @@ -0,0 +1,10 @@ +Added return values to lib2to3.refactor functions. + +:func:`lib2to3.refactor.refactor_file` now returns a status: + +* True: The file was written to. +* False: The file was not written to. +* None: The file could not be read. + +:func:`lib2to3.refactor.refactor_dir` now returns a list of all written +files.