From 6119be4d1213ac852dc0583aac38c7cbdec4b3a2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 26 Jun 2024 23:23:05 +0200 Subject: [PATCH 01/29] ENH fetch_file to fetch data files by URL with retries, checksuming and local caching --- sklearn/datasets/__init__.py | 2 + sklearn/datasets/_base.py | 200 ++++++++++++++++++++++++---- sklearn/datasets/tests/test_base.py | 35 +++++ 3 files changed, 211 insertions(+), 26 deletions(-) diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py index 58cddb099faff..f41b8fa7e73fb 100644 --- a/sklearn/datasets/__init__.py +++ b/sklearn/datasets/__init__.py @@ -4,6 +4,7 @@ from ._base import ( clear_data_home, + fetch_file, get_data_home, load_breast_cancer, load_diabetes, @@ -57,6 +58,7 @@ "dump_svmlight_file", "fetch_20newsgroups", "fetch_20newsgroups_vectorized", + "fetch_file", "fetch_lfw_pairs", "fetch_lfw_people", "fetch_olivetti_faces", diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 7dd2f181dee12..86283e05fa966 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -8,8 +8,10 @@ import gzip import hashlib import os +import re import shutil import time +import unicodedata import warnings from collections import namedtuple from importlib import resources @@ -17,7 +19,9 @@ from os import environ, listdir, makedirs from os.path import expanduser, isdir, join, splitext from pathlib import Path +from tempfile import NamedTemporaryFile from urllib.error import URLError +from urllib.parse import urlparse from urllib.request import urlretrieve import numpy as np @@ -1427,20 +1431,26 @@ def _sha256(path): def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): - """Helper function to download a remote dataset into path + """Helper function to download a remote dataset. Fetch a dataset pointed by remote's url, save into path using remote's - filename and ensure its integrity based on the SHA256 Checksum of the + filename and ensure its integrity based on the SHA256 checksum of the downloaded file. + .. versionchanged:: 1.6 + + If the file already exists locally and the SHA256 checksums match, the + path to the local file is returned without re-downloading. + Parameters ---------- remote : RemoteFileMetadata Named tuple containing remote dataset meta information: url, filename - and checksum + and checksum. - dirname : str - Directory to save the file to. + dirname : str or Path, default=None + Directory to save the file to. If None, the current working directory + is used. n_retries : int, default=3 Number of retries when HTTP errors are encountered. @@ -1454,28 +1464,166 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): Returns ------- - file_path: str + file_path: Path Full path of the created file. """ + if dirname is None: + folder_path = Path(".") + else: + folder_path = Path(dirname) + + file_path = folder_path / remote.filename + + if file_path.exists(): + if remote.checksum is None: + return file_path + + checksum = _sha256(file_path) + if checksum == remote.checksum: + return file_path + else: + warnings.warn( + f"SHA256 checksum of existing local file at {str(file_path)} " + f"({checksum}) differs from expected ({remote.checksum}): " + f"re-downloading from {remote.url} ." 
+ ) + + temp_file = NamedTemporaryFile( + prefix=remote.filename + ".part_", dir=folder_path, delete=False + ) + try: + temp_file_path = Path(temp_file.name) + while True: + try: + urlretrieve(remote.url, temp_file_path) + break + except (URLError, TimeoutError): + if n_retries == 0: + # If no more retries are left, re-raise the caught exception. + raise + warnings.warn(f"Retry downloading from url: {remote.url}") + n_retries -= 1 + time.sleep(delay) + + checksum = _sha256(temp_file_path) + if remote.checksum is not None and remote.checksum != checksum: + raise OSError( + f"{remote.filename} has an SHA256 checksum ({checksum}) " + f"differing from expected ({remote.checksum}), " + "file may be corrupted." + ) + except BaseException: + temp_file.close() + os.unlink(temp_file.name) + raise + + # The following renaming is atomic whenever temp_file_path and + # file_path are on the same filesystem. This should be the case most of + # the time, but we still use shutil.move instead of os.rename in case + # they are not. + shutil.move(temp_file_path, file_path) - file_path = remote.filename if dirname is None else join(dirname, remote.filename) - while True: - try: - urlretrieve(remote.url, file_path) - break - except (URLError, TimeoutError): - if n_retries == 0: - # If no more retries are left, re-raise the caught exception. - raise - warnings.warn(f"Retry downloading from url: {remote.url}") - n_retries -= 1 - time.sleep(delay) - - checksum = _sha256(file_path) - if remote.checksum != checksum: - raise OSError( - "{} has an SHA256 checksum ({}) " - "differing from expected ({}), " - "file may be corrupted.".format(file_path, checksum, remote.checksum) - ) return file_path + + +def _slugify(value, allow_unicode=False): + """Derive a name that is safe to use as filename from the given string. + + Adapted from + https://github.com/django/django/blob/master/django/utils/text.py + + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + + Note: this version keeps "." characters unchanged contrary to the django + version and replace other un-authorized characters by "_". + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^.\w\s-]", "_", value.lower()) + value = re.sub(r"_+", "_", value) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def _derive_folder_and_filename_from_url(url): + parsed_url = urlparse(url) + path = parsed_url.path + if not path: + path = "/" + + if "/" in path: + base_folder, filename = path.rsplit("/", 1) + + if not filename: + filename = "downloaded_file" + + base_folder = _slugify(base_folder) + if base_folder: + base_folder = "/" + base_folder + + return ( + _slugify(parsed_url.hostname) + base_folder, + _slugify(filename), + ) + + +def fetch_file( + url, folder=None, local_filename=None, sha256=None, n_retries=3, delay=1 +): + """Fetch a file from the web. + + If the file already exists locally and the SHA256 checksums match, the path + to the local file is returned without re-downloading. + + Parameters + ---------- + url : str + URL of the file to download. + + folder : str or Path, default=None + Directory to save the file to. 
If None, the file is downloaded in a + folder with a name derived from the URL host name and path under + scikit-learn data home folder. + + local_filename : str, default=None + Name of the file to save. If None, the filename is inferred from the + URL. + + sha256 : str, default=None + SHA256 checksum of the file. If None, no checksum is verified. + + n_retries : int, default=3 + Number of retries when HTTP errors are encountered. + + delay : int, default=1 + Number of seconds between retries. + + Returns + ------- + file_path : Path + Full path of the downloaded file. + """ + folder_from_url, filename_from_url = _derive_folder_and_filename_from_url(url) + + if local_filename is None: + local_filename = filename_from_url + + if folder is None: + folder = Path(get_data_home()) / folder_from_url + makedirs(folder, exist_ok=True) + + remote_metadata = RemoteFileMetadata( + filename=local_filename, url=url, checksum=sha256 + ) + return _fetch_remote( + remote_metadata, dirname=folder, n_retries=n_retries, delay=delay + ) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index b79f8c47c55c5..3ae8163396d71 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -28,6 +28,7 @@ ) from sklearn.datasets._base import ( RemoteFileMetadata, + _derive_folder_and_filename_from_url, _fetch_remote, load_csv_data, load_gzip_compressed_csv_data, @@ -391,3 +392,37 @@ def test_fetch_remote_raise_warnings_with_invalid_url(monkeypatch): for r in record: assert str(r.message) == f"Retry downloading from url: {url}" assert len(record) == 3 + + +def test_derive_folder_and_filename_from_url(): + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/file.tar.gz" + ) + assert folder == "example.com" + assert filename == "file.tar.gz" + + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/path/to/file.tar.gz" + ) + assert folder == "example.com/path_to" + assert filename == "file.tar.gz" + + folder, filename = _derive_folder_and_filename_from_url("https://example.com/") + assert folder == "example.com" + assert filename == "downloaded_file" + + folder, filename = _derive_folder_and_filename_from_url("https://example.com") + assert folder == "example.com" + assert filename == "downloaded_file" + + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/path/@to/data.json?param=value" + ) + assert folder == "example.com/path_to" + assert filename == "data.json" + + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/path/@to/data.json#anchor" + ) + assert folder == "example.com/path_to" + assert filename == "data.json" From a4c456d9142c2ad3c5d7de2124117da44339d6b7 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 09:37:13 +0200 Subject: [PATCH 02/29] TST add tests for fetch_file, with and without SHA256 checks --- sklearn/datasets/_base.py | 5 +- sklearn/datasets/tests/test_base.py | 141 ++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 86283e05fa966..900cb2e06a2f7 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1508,9 +1508,8 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): checksum = _sha256(temp_file_path) if remote.checksum is not None and remote.checksum != checksum: raise OSError( - f"{remote.filename} has an SHA256 checksum ({checksum}) " - f"differing from expected 
({remote.checksum}), " - "file may be corrupted." + f"The SHA256 checksum of {remote.filename} ({checksum}) " + f"differs from expected ({remote.checksum})." ) except BaseException: temp_file.close() diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 3ae8163396d71..bdb4e5e61757a 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -1,5 +1,7 @@ +import hashlib import io import os +import re import shutil import tempfile import warnings @@ -9,12 +11,14 @@ from pickle import dumps, loads from unittest.mock import Mock from urllib.error import HTTPError +from urllib.parse import urlparse import numpy as np import pytest from sklearn.datasets import ( clear_data_home, + fetch_file, get_data_home, load_breast_cancer, load_diabetes, @@ -426,3 +430,140 @@ def test_derive_folder_and_filename_from_url(): ) assert folder == "example.com/path_to" assert filename == "data.json" + + +def _mock_urlretrieve(server_side): + def _urlretrieve_mock(url, local_path): + file_path = urlparse(url).path + if not server_side.join(file_path).check(): + raise HTTPError(url, 404, "Not Found", None, None) + shutil.copy(server_side / file_path, local_path) + + return Mock(side_effect=_urlretrieve_mock) + + +def test_fetch_file_without_sha256(monkeypatch, tmpdir): + server_side = tmpdir.mkdir("server_side") + data_file = Path(server_side / "data.jsonl") + server_data = '{"a": 1, "b": 2}\n' + data_file.write_text(server_data, encoding="utf-8") + + client_side = tmpdir.mkdir("client_side") + + urlretrieve_mock = _mock_urlretrieve(server_side) + monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock) + + # The first call should trigger a download: + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 1 + + # Fetching again the same file to the same folder should do nothing: + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 1 + + # Deleting and calling again should re-download + fetched_file_path.unlink() + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 2 + + +def test_fetch_file_with_sha256(monkeypatch, tmpdir): + server_side = tmpdir.mkdir("server_side") + data_file = Path(server_side / "data.jsonl") + server_data = '{"a": 1, "b": 2}\n' + data_file.write_text(server_data, encoding="utf-8") + expected_sha256 = hashlib.sha256(data_file.read_bytes()).hexdigest() + + client_side = tmpdir.mkdir("client_side") + + urlretrieve_mock = _mock_urlretrieve(server_side) + monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock) + + # The first call should trigger a download. 
+ fetched_file_path = fetch_file( + "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256 + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 1 + + # Fetching again the same file to the same folder should do nothing when + # the sha256 match: + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256 + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 1 + + # Corrupting the local data should yield a warning and trigger a new download: + fetched_file_path.write_text("corruped contents", encoding="utf-8") + expected_msg = ( + r"SHA256 checksum of existing local file at .*client_side/data.jsonl " + rf"\(.*\) differs from expected \({expected_sha256}\): " + r"re-downloading from https://example.com/data.jsonl \." + ) + with pytest.warns(match=expected_msg): + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256 + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 2 + + # Calling again should do nothing: + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256 + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 2 + + # Deleting the local file and calling again should redownload without warning: + fetched_file_path.unlink() + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256 + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 3 + + # Calling without a sha256 should also work without redownloading: + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + ) + assert fetched_file_path == client_side / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 3 + + # Calling with a wrong sha256 should raise an informative exception: + non_matching_sha256 = "deadbabecafebeef" + expected_msg = re.escape( + f"The SHA256 checksum of data.jsonl ({expected_sha256}) differs from " + f"expected ({non_matching_sha256})." + ) + with pytest.raises(OSError, match=expected_msg): + fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + sha256=non_matching_sha256, + ) + # The local file should not have been deleted. + assert client_side.join("data.jsonl").read_text(encoding="utf-8") == server_data + assert urlretrieve_mock.call_count == 3 From d8bd1742ec43870da1fff73476f5b0d6a69640d0 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 09:39:35 +0200 Subject: [PATCH 03/29] Improve docstring --- sklearn/datasets/_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 900cb2e06a2f7..f3a5ed0d9f037 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1580,8 +1580,10 @@ def fetch_file( ): """Fetch a file from the web. 
- If the file already exists locally and the SHA256 checksums match, the path - to the local file is returned without re-downloading. + If the file already exists locally (and the SHA256 checksums match when + provided), the path to the local file is returned without re-downloading. + + .. versionadded:: 1.6 Parameters ---------- From ff808d3c8caf0c51be65789d5bc823fd7c4e2153 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 10:21:38 +0200 Subject: [PATCH 04/29] Test fetch_file's use of get_data_home --- sklearn/datasets/tests/test_base.py | 75 +++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index bdb4e5e61757a..6157589d052ae 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -434,14 +434,67 @@ def test_derive_folder_and_filename_from_url(): def _mock_urlretrieve(server_side): def _urlretrieve_mock(url, local_path): - file_path = urlparse(url).path - if not server_side.join(file_path).check(): + server_root = Path(server_side) + file_path = urlparse(url).path.strip("/") + if not (server_root / file_path).exists(): raise HTTPError(url, 404, "Not Found", None, None) - shutil.copy(server_side / file_path, local_path) + shutil.copy(server_root / file_path, local_path) return Mock(side_effect=_urlretrieve_mock) +def test_fetch_file_using_data_home(monkeypatch, tmpdir): + tmpdir = Path(tmpdir) + server_side = tmpdir / "server_side" + server_side.mkdir() + data_file = server_side / "data.jsonl" + server_data = '{"a": 1, "b": 2}\n' + data_file.write_text(server_data, encoding="utf-8") + + server_subfolder = server_side / "subfolder" + server_subfolder.mkdir() + other_data_file = server_subfolder / "other_file.txt" + other_data_file.write_text("Some important text data.", encoding="utf-8") + + data_home = tmpdir / "data_home" + data_home.mkdir() + + urlretrieve_mock = _mock_urlretrieve(server_side) + monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock) + + monkeypatch.setattr( + "sklearn.datasets._base.get_data_home", Mock(return_value=data_home) + ) + fetched_file_path = fetch_file( + "https://example.com/data.jsonl", + ) + assert fetched_file_path == data_home / "example.com" / "data.jsonl" + assert fetched_file_path.read_text(encoding="utf-8") == server_data + + fetched_file_path = fetch_file( + "https://example.com/subfolder/other_file.txt", + ) + assert ( + fetched_file_path == data_home / "example.com" / "subfolder" / "other_file.txt" + ) + assert fetched_file_path.read_text(encoding="utf-8") == other_data_file.read_text( + "utf-8" + ) + + expected_warning_msg = re.escape( + "Retry downloading from url: https://example.com/subfolder/invalid.txt" + ) + with pytest.raises(HTTPError): + with pytest.warns(match=expected_warning_msg): + fetch_file( + "https://example.com/subfolder/invalid.txt", + delay=0, + ) + + local_subfolder = data_home / "example.com" / "subfolder" + assert sorted(local_subfolder.iterdir()) == [local_subfolder / "other_file.txt"] + + def test_fetch_file_without_sha256(monkeypatch, tmpdir): server_side = tmpdir.mkdir("server_side") data_file = Path(server_side / "data.jsonl") @@ -554,16 +607,18 @@ def test_fetch_file_with_sha256(monkeypatch, tmpdir): # Calling with a wrong sha256 should raise an informative exception: non_matching_sha256 = "deadbabecafebeef" - expected_msg = re.escape( + expected_warning_msg = "differs from expected" + expected_error_msg = re.escape( f"The 
SHA256 checksum of data.jsonl ({expected_sha256}) differs from " f"expected ({non_matching_sha256})." ) - with pytest.raises(OSError, match=expected_msg): - fetch_file( - "https://example.com/data.jsonl", - folder=client_side, - sha256=non_matching_sha256, - ) + with pytest.raises(OSError, match=expected_error_msg): + with pytest.warns(match=expected_warning_msg): + fetch_file( + "https://example.com/data.jsonl", + folder=client_side, + sha256=non_matching_sha256, + ) # The local file should not have been deleted. assert client_side.join("data.jsonl").read_text(encoding="utf-8") == server_data assert urlretrieve_mock.call_count == 3 From 219a077a31a3bfc4300780baf9bd09aa9507fcb1 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 10:31:11 +0200 Subject: [PATCH 05/29] Add changelog entry --- doc/whats_new/v1.6.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 1a6337f3ad746..0501bd2a44e7c 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -114,6 +114,15 @@ Changelog on the input data. :pr:`29124` by :user:`Yao Xiao `. + +:mod:`sklearn.datasets` +....................... + +- |Feature| :func:`datasets.fetch_file` allows downloading arbitrary data-file + from the web. It handles local caching, integrity checks with SHA256 digests + and automatic retries in case of HTTP errors. :pr:`Olivier Grisel ` + by :user:`Olivier Grisel `. + :mod:`sklearn.discriminant_analysis` .................................... From b969c85d630b56d0828cbf91f2dbe20126267f9f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 11:03:25 +0200 Subject: [PATCH 06/29] Fix PR number in changelog entry... --- doc/whats_new/v1.6.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 0501bd2a44e7c..be298cabfa8dd 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -120,8 +120,8 @@ Changelog - |Feature| :func:`datasets.fetch_file` allows downloading arbitrary data-file from the web. It handles local caching, integrity checks with SHA256 digests - and automatic retries in case of HTTP errors. :pr:`Olivier Grisel ` - by :user:`Olivier Grisel `. + and automatic retries in case of HTTP errors. :pr:`29354` by :user:`Olivier + Grisel `. :mod:`sklearn.discriminant_analysis` .................................... From b6900a10866b8bd1ef4186c5618fbf6b7f5be50c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 11:51:19 +0200 Subject: [PATCH 07/29] Close the temp file earlier to make Windows happier? --- sklearn/datasets/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index f3a5ed0d9f037..29fd786a60c41 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1491,6 +1491,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): temp_file = NamedTemporaryFile( prefix=remote.filename + ".part_", dir=folder_path, delete=False ) + temp_file.close() try: temp_file_path = Path(temp_file.name) while True: @@ -1512,7 +1513,6 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): f"differs from expected ({remote.checksum})." 
) except BaseException: - temp_file.close() os.unlink(temp_file.name) raise From 77ce36e4edb479fa4e3512230b8d3cc6efee4fce Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 12:19:25 +0200 Subject: [PATCH 08/29] Update example on feature engineering with Polars for bike sharing demand to use fetch_file --- .../plot_time_series_lagged_features.py | 857 +++++++++--------- 1 file changed, 432 insertions(+), 425 deletions(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index 9159825cbbd43..28f210ef72496 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -1,425 +1,432 @@ -""" -=========================================== -Lagged features for time series forecasting -=========================================== - -This example demonstrates how Polars-engineered lagged features can be used -for time series forecasting with -:class:`~sklearn.ensemble.HistGradientBoostingRegressor` on the Bike Sharing -Demand dataset. - -See the example on -:ref:`sphx_glr_auto_examples_applications_plot_cyclical_feature_engineering.py` -for some data exploration on this dataset and a demo on periodic feature -engineering. - -""" - -# %% -# Analyzing the Bike Sharing Demand dataset -# ----------------------------------------- -# -# We start by loading the data from the OpenML repository -# as a pandas dataframe. This will be replaced with Polars -# once `fetch_openml` adds a native support for it. -# We convert to Polars for feature engineering, as it automatically caches -# common subexpressions which are reused in multiple expressions -# (like `pl.col("count").shift(1)` below). See -# https://docs.pola.rs/user-guide/lazy/optimizations/ for more information. - -import numpy as np -import polars as pl - -from sklearn.datasets import fetch_openml - -pl.Config.set_fmt_str_lengths(20) - -bike_sharing = fetch_openml( - "Bike_Sharing_Demand", version=2, as_frame=True, parser="pandas" -) -df = bike_sharing.frame -df = pl.DataFrame({col: df[col].to_numpy() for col in df.columns}) - -# %% -# Next, we take a look at the statistical summary of the dataset -# so that we can better understand the data that we are working with. -import polars.selectors as cs - -summary = df.select(cs.numeric()).describe() -summary - -# %% -# Let us look at the count of the seasons `"fall"`, `"spring"`, `"summer"` -# and `"winter"` present in the dataset to confirm they are balanced. - -import matplotlib.pyplot as plt - -df["season"].value_counts() - - -# %% -# Generating Polars-engineered lagged features -# -------------------------------------------- -# Let's consider the problem of predicting the demand at the -# next hour given past demands. Since the demand is a continuous -# variable, one could intuitively use any regression model. However, we do -# not have the usual `(X_train, y_train)` dataset. Instead, we just have -# the `y_train` demand data sequentially organized by time. 
-lagged_df = df.select( - "count", - *[pl.col("count").shift(i).alias(f"lagged_count_{i}h") for i in [1, 2, 3]], - lagged_count_1d=pl.col("count").shift(24), - lagged_count_1d_1h=pl.col("count").shift(24 + 1), - lagged_count_7d=pl.col("count").shift(7 * 24), - lagged_count_7d_1h=pl.col("count").shift(7 * 24 + 1), - lagged_mean_24h=pl.col("count").shift(1).rolling_mean(24), - lagged_max_24h=pl.col("count").shift(1).rolling_max(24), - lagged_min_24h=pl.col("count").shift(1).rolling_min(24), - lagged_mean_7d=pl.col("count").shift(1).rolling_mean(7 * 24), - lagged_max_7d=pl.col("count").shift(1).rolling_max(7 * 24), - lagged_min_7d=pl.col("count").shift(1).rolling_min(7 * 24), -) -lagged_df.tail(10) - -# %% -# Watch out however, the first lines have undefined values because their own -# past is unknown. This depends on how much lag we used: -lagged_df.head(10) - -# %% -# We can now separate the lagged features in a matrix `X` and the target variable -# (the counts to predict) in an array of the same first dimension `y`. -lagged_df = lagged_df.drop_nulls() -X = lagged_df.drop("count") -y = lagged_df["count"] -print("X shape: {}\ny shape: {}".format(X.shape, y.shape)) - -# %% -# Naive evaluation of the next hour bike demand regression -# -------------------------------------------------------- -# Let's randomly split our tabularized dataset to train a gradient -# boosting regression tree (GBRT) model and evaluate it using Mean -# Absolute Percentage Error (MAPE). If our model is aimed at forecasting -# (i.e., predicting future data from past data), we should not use training -# data that are ulterior to the testing data. In time series machine learning -# the "i.i.d" (independent and identically distributed) assumption does not -# hold true as the data points are not independent and have a temporal -# relationship. -from sklearn.ensemble import HistGradientBoostingRegressor -from sklearn.model_selection import train_test_split - -X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42 -) - -model = HistGradientBoostingRegressor().fit(X_train, y_train) - -# %% -# Taking a look at the performance of the model. -from sklearn.metrics import mean_absolute_percentage_error - -y_pred = model.predict(X_test) -mean_absolute_percentage_error(y_test, y_pred) - -# %% -# Proper next hour forecasting evaluation -# --------------------------------------- -# Let's use a proper evaluation splitting strategies that takes into account -# the temporal structure of the dataset to evaluate our model's ability to -# predict data points in the future (to avoid cheating by reading values from -# the lagged features in the training set). -from sklearn.model_selection import TimeSeriesSplit - -ts_cv = TimeSeriesSplit( - n_splits=3, # to keep the notebook fast enough on common laptops - gap=48, # 2 days data gap between train and test - max_train_size=10000, # keep train sets of comparable sizes - test_size=3000, # for 2 or 3 digits of precision in scores -) -all_splits = list(ts_cv.split(X, y)) - -# %% -# Training the model and evaluating its performance based on MAPE. -train_idx, test_idx = all_splits[0] -X_train, X_test = X[train_idx, :], X[test_idx, :] -y_train, y_test = y[train_idx], y[test_idx] - -model = HistGradientBoostingRegressor().fit(X_train, y_train) -y_pred = model.predict(X_test) -mean_absolute_percentage_error(y_test, y_pred) - -# %% -# The generalization error measured via a shuffled trained test split -# is too optimistic. 
The generalization via a time-based split is likely to -# be more representative of the true performance of the regression model. -# Let's assess this variability of our error evaluation with proper -# cross-validation: -from sklearn.model_selection import cross_val_score - -cv_mape_scores = -cross_val_score( - model, X, y, cv=ts_cv, scoring="neg_mean_absolute_percentage_error" -) -cv_mape_scores - -# %% -# The variability across splits is quite large! In a real life setting -# it would be advised to use more splits to better assess the variability. -# Let's report the mean CV scores and their standard deviation from now on. -print(f"CV MAPE: {cv_mape_scores.mean():.3f} ± {cv_mape_scores.std():.3f}") - -# %% -# We can compute several combinations of evaluation metrics and loss functions, -# which are reported a bit below. -from collections import defaultdict - -from sklearn.metrics import ( - make_scorer, - mean_absolute_error, - mean_pinball_loss, - root_mean_squared_error, -) -from sklearn.model_selection import cross_validate - - -def consolidate_scores(cv_results, scores, metric): - if metric == "MAPE": - scores[metric].append(f"{value.mean():.2f} ± {value.std():.2f}") - else: - scores[metric].append(f"{value.mean():.1f} ± {value.std():.1f}") - - return scores - - -scoring = { - "MAPE": make_scorer(mean_absolute_percentage_error), - "RMSE": make_scorer(root_mean_squared_error), - "MAE": make_scorer(mean_absolute_error), - "pinball_loss_05": make_scorer(mean_pinball_loss, alpha=0.05), - "pinball_loss_50": make_scorer(mean_pinball_loss, alpha=0.50), - "pinball_loss_95": make_scorer(mean_pinball_loss, alpha=0.95), -} -loss_functions = ["squared_error", "poisson", "absolute_error"] -scores = defaultdict(list) -for loss_func in loss_functions: - model = HistGradientBoostingRegressor(loss=loss_func) - cv_results = cross_validate( - model, - X, - y, - cv=ts_cv, - scoring=scoring, - n_jobs=2, - ) - time = cv_results["fit_time"] - scores["loss"].append(loss_func) - scores["fit_time"].append(f"{time.mean():.2f} ± {time.std():.2f} s") - - for key, value in cv_results.items(): - if key.startswith("test_"): - metric = key.split("test_")[1] - scores = consolidate_scores(cv_results, scores, metric) - - -# %% -# Modeling predictive uncertainty via quantile regression -# ------------------------------------------------------- -# Instead of modeling the expected value of the distribution of -# :math:`Y|X` like the least squares and Poisson losses do, one could try to -# estimate quantiles of the conditional distribution. -# -# :math:`Y|X=x_i` is expected to be a random variable for a given data point -# :math:`x_i` because we expect that the number of rentals cannot be 100% -# accurately predicted from the features. It can be influenced by other -# variables not properly captured by the existing lagged features. For -# instance whether or not it will rain in the next hour cannot be fully -# anticipated from the past hours bike rental data. This is what we -# call aleatoric uncertainty. -# -# Quantile regression makes it possible to give a finer description of that -# distribution without making strong assumptions on its shape. 
-quantile_list = [0.05, 0.5, 0.95] - -for quantile in quantile_list: - model = HistGradientBoostingRegressor(loss="quantile", quantile=quantile) - cv_results = cross_validate( - model, - X, - y, - cv=ts_cv, - scoring=scoring, - n_jobs=2, - ) - time = cv_results["fit_time"] - scores["fit_time"].append(f"{time.mean():.2f} ± {time.std():.2f} s") - - scores["loss"].append(f"quantile {int(quantile*100)}") - for key, value in cv_results.items(): - if key.startswith("test_"): - metric = key.split("test_")[1] - scores = consolidate_scores(cv_results, scores, metric) - -scores_df = pl.DataFrame(scores) -scores_df - - -# %% -# Let us take a look at the losses that minimise each metric. -def min_arg(col): - col_split = pl.col(col).str.split(" ") - return pl.arg_sort_by( - col_split.list.get(0).cast(pl.Float64), - col_split.list.get(2).cast(pl.Float64), - ).first() - - -scores_df.select( - pl.col("loss").get(min_arg(col_name)).alias(col_name) - for col_name in scores_df.columns - if col_name != "loss" -) - -# %% -# Even if the score distributions overlap due to the variance in the dataset, -# it is true that the average RMSE is lower when `loss="squared_error"`, whereas -# the average MAPE is lower when `loss="absolute_error"` as expected. That is -# also the case for the Mean Pinball Loss with the quantiles 5 and 95. The score -# corresponding to the 50 quantile loss is overlapping with the score obtained -# by minimizing other loss functions, which is also the case for the MAE. -# -# A qualitative look at the predictions -# ------------------------------------- -# We can now visualize the performance of the model with regards -# to the 5th percentile, median and the 95th percentile: -all_splits = list(ts_cv.split(X, y)) -train_idx, test_idx = all_splits[0] - -X_train, X_test = X[train_idx, :], X[test_idx, :] -y_train, y_test = y[train_idx], y[test_idx] - -max_iter = 50 -gbrt_mean_poisson = HistGradientBoostingRegressor(loss="poisson", max_iter=max_iter) -gbrt_mean_poisson.fit(X_train, y_train) -mean_predictions = gbrt_mean_poisson.predict(X_test) - -gbrt_median = HistGradientBoostingRegressor( - loss="quantile", quantile=0.5, max_iter=max_iter -) -gbrt_median.fit(X_train, y_train) -median_predictions = gbrt_median.predict(X_test) - -gbrt_percentile_5 = HistGradientBoostingRegressor( - loss="quantile", quantile=0.05, max_iter=max_iter -) -gbrt_percentile_5.fit(X_train, y_train) -percentile_5_predictions = gbrt_percentile_5.predict(X_test) - -gbrt_percentile_95 = HistGradientBoostingRegressor( - loss="quantile", quantile=0.95, max_iter=max_iter -) -gbrt_percentile_95.fit(X_train, y_train) -percentile_95_predictions = gbrt_percentile_95.predict(X_test) - -# %% -# We can now take a look at the predictions made by the regression models: -last_hours = slice(-96, None) -fig, ax = plt.subplots(figsize=(15, 7)) -plt.title("Predictions by regression models") -ax.plot( - y_test[last_hours], - "x-", - alpha=0.2, - label="Actual demand", - color="black", -) -ax.plot( - median_predictions[last_hours], - "^-", - label="GBRT median", -) -ax.plot( - mean_predictions[last_hours], - "x-", - label="GBRT mean (Poisson)", -) -ax.fill_between( - np.arange(96), - percentile_5_predictions[last_hours], - percentile_95_predictions[last_hours], - alpha=0.3, - label="GBRT 90% interval", -) -_ = ax.legend() - -# %% -# Here it's interesting to notice that the blue area between the 5% and 95% -# percentile estimators has a width that varies with the time of the day: -# -# - At night, the blue band is much narrower: the pair of 
models is quite -# certain that there will be a small number of bike rentals. And furthermore -# these seem correct in the sense that the actual demand stays in that blue -# band. -# - During the day, the blue band is much wider: the uncertainty grows, probably -# because of the variability of the weather that can have a very large impact, -# especially on week-ends. -# - We can also see that during week-days, the commute pattern is still visible in -# the 5% and 95% estimations. -# - Finally, it is expected that 10% of the time, the actual demand does not lie -# between the 5% and 95% percentile estimates. On this test span, the actual -# demand seems to be higher, especially during the rush hours. It might reveal that -# our 95% percentile estimator underestimates the demand peaks. This could be be -# quantitatively confirmed by computing empirical coverage numbers as done in -# the :ref:`calibration of confidence intervals `. -# -# Looking at the performance of non-linear regression models vs -# the best models: -from sklearn.metrics import PredictionErrorDisplay - -fig, axes = plt.subplots(ncols=3, figsize=(15, 6), sharey=True) -fig.suptitle("Non-linear regression models") -predictions = [ - median_predictions, - percentile_5_predictions, - percentile_95_predictions, -] -labels = [ - "Median", - "5th percentile", - "95th percentile", -] -for ax, pred, label in zip(axes, predictions, labels): - PredictionErrorDisplay.from_predictions( - y_true=y_test, - y_pred=pred, - kind="residual_vs_predicted", - scatter_kwargs={"alpha": 0.3}, - ax=ax, - ) - ax.set(xlabel="Predicted demand", ylabel="True demand") - ax.legend(["Best model", label]) - -plt.show() - -# %% -# Conclusion -# ---------- -# Through this example we explored time series forecasting using lagged -# features. We compared a naive regression (using the standardized -# :class:`~sklearn.model_selection.train_test_split`) with a proper time -# series evaluation strategy using -# :class:`~sklearn.model_selection.TimeSeriesSplit`. We observed that the -# model trained using :class:`~sklearn.model_selection.train_test_split`, -# having a default value of `shuffle` set to `True` produced an overly -# optimistic Mean Average Percentage Error (MAPE). The results -# produced from the time-based split better represent the performance -# of our time-series regression model. We also analyzed the predictive uncertainty -# of our model via Quantile Regression. Predictions based on the 5th and -# 95th percentile using `loss="quantile"` provide us with a quantitative estimate -# of the uncertainty of the forecasts made by our time series regression model. -# Uncertainty estimation can also be performed -# using `MAPIE `_, -# that provides an implementation based on recent work on conformal prediction -# methods and estimates both aleatoric and epistemic uncertainty at the same time. -# Furthermore, functionalities provided -# by `sktime `_ -# can be used to extend scikit-learn estimators by making use of recursive time -# series forecasting, that enables dynamic predictions of future values. +""" +=========================================== +Lagged features for time series forecasting +=========================================== + +This example demonstrates how Polars-engineered lagged features can be used +for time series forecasting with +:class:`~sklearn.ensemble.HistGradientBoostingRegressor` on the Bike Sharing +Demand dataset. 
+ +See the example on +:ref:`sphx_glr_auto_examples_applications_plot_cyclical_feature_engineering.py` +for some data exploration on this dataset and a demo on periodic feature +engineering. + +""" + +# %% +# Analyzing the Bike Sharing Demand dataset +# ----------------------------------------- +# +# We start by loading the data from the OpenML repository as a raw parquet file +# to illustrate how to work with arbitrary parquet file instead of hiding this +# step in a convenience too such as `sklearn.datasets.fetch_openml`. +import numpy as np +import polars as pl + +from sklearn.datasets import fetch_file + +pl.Config.set_fmt_str_lengths(20) + +# Direct download of the parquet file of the Bike Sharing Demand v7 dataset on +# openml.org: +bike_sharing_data_file = fetch_file( + "https://openml1.win.tue.nl/datasets/0004/44063/dataset_44063.pq", + sha256="d120af76829af0d256338dc6dd4be5df4fd1f35bf3a283cab66a51c1c6abd06a", +) +bike_sharing_data_file + +# %% +# We load the parquet file with Polars for feature engineering, as it +# automatically caches common subexpressions which are reused in multiple +# expressions (like `pl.col("count").shift(1)` below). See +# https://docs.pola.rs/user-guide/lazy/optimizations/ for more information. + +df = pl.read_parquet(bike_sharing_data_file) + +# %% +# Next, we take a look at the statistical summary of the dataset +# so that we can better understand the data that we are working with. +import polars.selectors as cs + +summary = df.select(cs.numeric()).describe() +summary + +# %% +# Let us look at the count of the seasons `"fall"`, `"spring"`, `"summer"` +# and `"winter"` present in the dataset to confirm they are balanced. + +import matplotlib.pyplot as plt + +df["season"].value_counts() + + +# %% +# Generating Polars-engineered lagged features +# -------------------------------------------- +# Let's consider the problem of predicting the demand at the +# next hour given past demands. Since the demand is a continuous +# variable, one could intuitively use any regression model. However, we do +# not have the usual `(X_train, y_train)` dataset. Instead, we just have +# the `y_train` demand data sequentially organized by time. +lagged_df = df.select( + "count", + *[pl.col("count").shift(i).alias(f"lagged_count_{i}h") for i in [1, 2, 3]], + lagged_count_1d=pl.col("count").shift(24), + lagged_count_1d_1h=pl.col("count").shift(24 + 1), + lagged_count_7d=pl.col("count").shift(7 * 24), + lagged_count_7d_1h=pl.col("count").shift(7 * 24 + 1), + lagged_mean_24h=pl.col("count").shift(1).rolling_mean(24), + lagged_max_24h=pl.col("count").shift(1).rolling_max(24), + lagged_min_24h=pl.col("count").shift(1).rolling_min(24), + lagged_mean_7d=pl.col("count").shift(1).rolling_mean(7 * 24), + lagged_max_7d=pl.col("count").shift(1).rolling_max(7 * 24), + lagged_min_7d=pl.col("count").shift(1).rolling_min(7 * 24), +) +lagged_df.tail(10) + +# %% +# Watch out however, the first lines have undefined values because their own +# past is unknown. This depends on how much lag we used: +lagged_df.head(10) + +# %% +# We can now separate the lagged features in a matrix `X` and the target variable +# (the counts to predict) in an array of the same first dimension `y`. 
+lagged_df = lagged_df.drop_nulls() +X = lagged_df.drop("count") +y = lagged_df["count"] +print("X shape: {}\ny shape: {}".format(X.shape, y.shape)) + +# %% +# Naive evaluation of the next hour bike demand regression +# -------------------------------------------------------- +# Let's randomly split our tabularized dataset to train a gradient +# boosting regression tree (GBRT) model and evaluate it using Mean +# Absolute Percentage Error (MAPE). If our model is aimed at forecasting +# (i.e., predicting future data from past data), we should not use training +# data that are ulterior to the testing data. In time series machine learning +# the "i.i.d" (independent and identically distributed) assumption does not +# hold true as the data points are not independent and have a temporal +# relationship. +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.model_selection import train_test_split + +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +model = HistGradientBoostingRegressor().fit(X_train, y_train) + +# %% +# Taking a look at the performance of the model. +from sklearn.metrics import mean_absolute_percentage_error + +y_pred = model.predict(X_test) +mean_absolute_percentage_error(y_test, y_pred) + +# %% +# Proper next hour forecasting evaluation +# --------------------------------------- +# Let's use a proper evaluation splitting strategies that takes into account +# the temporal structure of the dataset to evaluate our model's ability to +# predict data points in the future (to avoid cheating by reading values from +# the lagged features in the training set). +from sklearn.model_selection import TimeSeriesSplit + +ts_cv = TimeSeriesSplit( + n_splits=3, # to keep the notebook fast enough on common laptops + gap=48, # 2 days data gap between train and test + max_train_size=10000, # keep train sets of comparable sizes + test_size=3000, # for 2 or 3 digits of precision in scores +) +all_splits = list(ts_cv.split(X, y)) + +# %% +# Training the model and evaluating its performance based on MAPE. +train_idx, test_idx = all_splits[0] +X_train, X_test = X[train_idx, :], X[test_idx, :] +y_train, y_test = y[train_idx], y[test_idx] + +model = HistGradientBoostingRegressor().fit(X_train, y_train) +y_pred = model.predict(X_test) +mean_absolute_percentage_error(y_test, y_pred) + +# %% +# The generalization error measured via a shuffled trained test split +# is too optimistic. The generalization via a time-based split is likely to +# be more representative of the true performance of the regression model. +# Let's assess this variability of our error evaluation with proper +# cross-validation: +from sklearn.model_selection import cross_val_score + +cv_mape_scores = -cross_val_score( + model, X, y, cv=ts_cv, scoring="neg_mean_absolute_percentage_error" +) +cv_mape_scores + +# %% +# The variability across splits is quite large! In a real life setting +# it would be advised to use more splits to better assess the variability. +# Let's report the mean CV scores and their standard deviation from now on. +print(f"CV MAPE: {cv_mape_scores.mean():.3f} ± {cv_mape_scores.std():.3f}") + +# %% +# We can compute several combinations of evaluation metrics and loss functions, +# which are reported a bit below. 
+from collections import defaultdict + +from sklearn.metrics import ( + make_scorer, + mean_absolute_error, + mean_pinball_loss, + root_mean_squared_error, +) +from sklearn.model_selection import cross_validate + + +def consolidate_scores(cv_results, scores, metric): + if metric == "MAPE": + scores[metric].append(f"{value.mean():.2f} ± {value.std():.2f}") + else: + scores[metric].append(f"{value.mean():.1f} ± {value.std():.1f}") + + return scores + + +scoring = { + "MAPE": make_scorer(mean_absolute_percentage_error), + "RMSE": make_scorer(root_mean_squared_error), + "MAE": make_scorer(mean_absolute_error), + "pinball_loss_05": make_scorer(mean_pinball_loss, alpha=0.05), + "pinball_loss_50": make_scorer(mean_pinball_loss, alpha=0.50), + "pinball_loss_95": make_scorer(mean_pinball_loss, alpha=0.95), +} +loss_functions = ["squared_error", "poisson", "absolute_error"] +scores = defaultdict(list) +for loss_func in loss_functions: + model = HistGradientBoostingRegressor(loss=loss_func) + cv_results = cross_validate( + model, + X, + y, + cv=ts_cv, + scoring=scoring, + n_jobs=2, + ) + time = cv_results["fit_time"] + scores["loss"].append(loss_func) + scores["fit_time"].append(f"{time.mean():.2f} ± {time.std():.2f} s") + + for key, value in cv_results.items(): + if key.startswith("test_"): + metric = key.split("test_")[1] + scores = consolidate_scores(cv_results, scores, metric) + + +# %% +# Modeling predictive uncertainty via quantile regression +# ------------------------------------------------------- +# Instead of modeling the expected value of the distribution of +# :math:`Y|X` like the least squares and Poisson losses do, one could try to +# estimate quantiles of the conditional distribution. +# +# :math:`Y|X=x_i` is expected to be a random variable for a given data point +# :math:`x_i` because we expect that the number of rentals cannot be 100% +# accurately predicted from the features. It can be influenced by other +# variables not properly captured by the existing lagged features. For +# instance whether or not it will rain in the next hour cannot be fully +# anticipated from the past hours bike rental data. This is what we +# call aleatoric uncertainty. +# +# Quantile regression makes it possible to give a finer description of that +# distribution without making strong assumptions on its shape. +quantile_list = [0.05, 0.5, 0.95] + +for quantile in quantile_list: + model = HistGradientBoostingRegressor(loss="quantile", quantile=quantile) + cv_results = cross_validate( + model, + X, + y, + cv=ts_cv, + scoring=scoring, + n_jobs=2, + ) + time = cv_results["fit_time"] + scores["fit_time"].append(f"{time.mean():.2f} ± {time.std():.2f} s") + + scores["loss"].append(f"quantile {int(quantile*100)}") + for key, value in cv_results.items(): + if key.startswith("test_"): + metric = key.split("test_")[1] + scores = consolidate_scores(cv_results, scores, metric) + +scores_df = pl.DataFrame(scores) +scores_df + + +# %% +# Let us take a look at the losses that minimise each metric. 
+def min_arg(col): + col_split = pl.col(col).str.split(" ") + return pl.arg_sort_by( + col_split.list.get(0).cast(pl.Float64), + col_split.list.get(2).cast(pl.Float64), + ).first() + + +scores_df.select( + pl.col("loss").get(min_arg(col_name)).alias(col_name) + for col_name in scores_df.columns + if col_name != "loss" +) + +# %% +# Even if the score distributions overlap due to the variance in the dataset, +# it is true that the average RMSE is lower when `loss="squared_error"`, whereas +# the average MAPE is lower when `loss="absolute_error"` as expected. That is +# also the case for the Mean Pinball Loss with the quantiles 5 and 95. The score +# corresponding to the 50 quantile loss is overlapping with the score obtained +# by minimizing other loss functions, which is also the case for the MAE. +# +# A qualitative look at the predictions +# ------------------------------------- +# We can now visualize the performance of the model with regards +# to the 5th percentile, median and the 95th percentile: +all_splits = list(ts_cv.split(X, y)) +train_idx, test_idx = all_splits[0] + +X_train, X_test = X[train_idx, :], X[test_idx, :] +y_train, y_test = y[train_idx], y[test_idx] + +max_iter = 50 +gbrt_mean_poisson = HistGradientBoostingRegressor(loss="poisson", max_iter=max_iter) +gbrt_mean_poisson.fit(X_train, y_train) +mean_predictions = gbrt_mean_poisson.predict(X_test) + +gbrt_median = HistGradientBoostingRegressor( + loss="quantile", quantile=0.5, max_iter=max_iter +) +gbrt_median.fit(X_train, y_train) +median_predictions = gbrt_median.predict(X_test) + +gbrt_percentile_5 = HistGradientBoostingRegressor( + loss="quantile", quantile=0.05, max_iter=max_iter +) +gbrt_percentile_5.fit(X_train, y_train) +percentile_5_predictions = gbrt_percentile_5.predict(X_test) + +gbrt_percentile_95 = HistGradientBoostingRegressor( + loss="quantile", quantile=0.95, max_iter=max_iter +) +gbrt_percentile_95.fit(X_train, y_train) +percentile_95_predictions = gbrt_percentile_95.predict(X_test) + +# %% +# We can now take a look at the predictions made by the regression models: +last_hours = slice(-96, None) +fig, ax = plt.subplots(figsize=(15, 7)) +plt.title("Predictions by regression models") +ax.plot( + y_test[last_hours], + "x-", + alpha=0.2, + label="Actual demand", + color="black", +) +ax.plot( + median_predictions[last_hours], + "^-", + label="GBRT median", +) +ax.plot( + mean_predictions[last_hours], + "x-", + label="GBRT mean (Poisson)", +) +ax.fill_between( + np.arange(96), + percentile_5_predictions[last_hours], + percentile_95_predictions[last_hours], + alpha=0.3, + label="GBRT 90% interval", +) +_ = ax.legend() + +# %% +# Here it's interesting to notice that the blue area between the 5% and 95% +# percentile estimators has a width that varies with the time of the day: +# +# - At night, the blue band is much narrower: the pair of models is quite +# certain that there will be a small number of bike rentals. And furthermore +# these seem correct in the sense that the actual demand stays in that blue +# band. +# - During the day, the blue band is much wider: the uncertainty grows, probably +# because of the variability of the weather that can have a very large impact, +# especially on week-ends. +# - We can also see that during week-days, the commute pattern is still visible in +# the 5% and 95% estimations. +# - Finally, it is expected that 10% of the time, the actual demand does not lie +# between the 5% and 95% percentile estimates. 
On this test span, the actual +# demand seems to be higher, especially during the rush hours. It might reveal that +# our 95% percentile estimator underestimates the demand peaks. This could be be +# quantitatively confirmed by computing empirical coverage numbers as done in +# the :ref:`calibration of confidence intervals `. +# +# Looking at the performance of non-linear regression models vs +# the best models: +from sklearn.metrics import PredictionErrorDisplay + +fig, axes = plt.subplots(ncols=3, figsize=(15, 6), sharey=True) +fig.suptitle("Non-linear regression models") +predictions = [ + median_predictions, + percentile_5_predictions, + percentile_95_predictions, +] +labels = [ + "Median", + "5th percentile", + "95th percentile", +] +for ax, pred, label in zip(axes, predictions, labels): + PredictionErrorDisplay.from_predictions( + y_true=y_test, + y_pred=pred, + kind="residual_vs_predicted", + scatter_kwargs={"alpha": 0.3}, + ax=ax, + ) + ax.set(xlabel="Predicted demand", ylabel="True demand") + ax.legend(["Best model", label]) + +plt.show() + +# %% +# Conclusion +# ---------- +# Through this example we explored time series forecasting using lagged +# features. We compared a naive regression (using the standardized +# :class:`~sklearn.model_selection.train_test_split`) with a proper time +# series evaluation strategy using +# :class:`~sklearn.model_selection.TimeSeriesSplit`. We observed that the +# model trained using :class:`~sklearn.model_selection.train_test_split`, +# having a default value of `shuffle` set to `True` produced an overly +# optimistic Mean Average Percentage Error (MAPE). The results +# produced from the time-based split better represent the performance +# of our time-series regression model. We also analyzed the predictive uncertainty +# of our model via Quantile Regression. Predictions based on the 5th and +# 95th percentile using `loss="quantile"` provide us with a quantitative estimate +# of the uncertainty of the forecasts made by our time series regression model. +# Uncertainty estimation can also be performed +# using `MAPIE `_, +# that provides an implementation based on recent work on conformal prediction +# methods and estimates both aleatoric and epistemic uncertainty at the same time. +# Furthermore, functionalities provided +# by `sktime `_ +# can be used to extend scikit-learn estimators by making use of recursive time +# series forecasting, that enables dynamic predictions of future values. + +# %% From c7a35e05d97df443cb6ad0dfdc4029451b07cba6 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 13:55:00 +0200 Subject: [PATCH 09/29] Make expected warning message OS independent --- sklearn/datasets/tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 6157589d052ae..78961dd8b183d 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -567,7 +567,7 @@ def test_fetch_file_with_sha256(monkeypatch, tmpdir): # Corrupting the local data should yield a warning and trigger a new download: fetched_file_path.write_text("corruped contents", encoding="utf-8") expected_msg = ( - r"SHA256 checksum of existing local file at .*client_side/data.jsonl " + r"SHA256 checksum of existing local file at .*data.jsonl " rf"\(.*\) differs from expected \({expected_sha256}\): " r"re-downloading from https://example.com/data.jsonl \." 
) From c09ddff34fcc81e8d045b1f828420cc298949342 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 13:56:21 +0200 Subject: [PATCH 10/29] Shorter warning message --- sklearn/datasets/_base.py | 2 +- sklearn/datasets/tests/test_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 29fd786a60c41..f665a8b8c51b9 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1483,7 +1483,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): return file_path else: warnings.warn( - f"SHA256 checksum of existing local file at {str(file_path)} " + f"SHA256 checksum of existing local file {file_path.name} " f"({checksum}) differs from expected ({remote.checksum}): " f"re-downloading from {remote.url} ." ) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 78961dd8b183d..a38d4091b4205 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -567,7 +567,7 @@ def test_fetch_file_with_sha256(monkeypatch, tmpdir): # Corrupting the local data should yield a warning and trigger a new download: fetched_file_path.write_text("corruped contents", encoding="utf-8") expected_msg = ( - r"SHA256 checksum of existing local file at .*data.jsonl " + r"SHA256 checksum of existing local file data.jsonl " rf"\(.*\) differs from expected \({expected_sha256}\): " r"re-downloading from https://example.com/data.jsonl \." ) From 8368fe7133733a7bff0fa920ce6a9544ff584d6e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 14:02:27 +0200 Subject: [PATCH 11/29] Improve phrasing in the first example cell --- .../applications/plot_time_series_lagged_features.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index 28f210ef72496..70823f74f17bf 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -20,8 +20,11 @@ # ----------------------------------------- # # We start by loading the data from the OpenML repository as a raw parquet file -# to illustrate how to work with arbitrary parquet file instead of hiding this +# to illustrate how to work with an arbitrary parquet file instead of hiding this # step in a convenience too such as `sklearn.datasets.fetch_openml`. +# +# The URL of the parquet file can be found in the JSON description of the +# Bike Sharing Demand v7 dataset on openml.org. import numpy as np import polars as pl @@ -29,8 +32,6 @@ pl.Config.set_fmt_str_lengths(20) -# Direct download of the parquet file of the Bike Sharing Demand v7 dataset on -# openml.org: bike_sharing_data_file = fetch_file( "https://openml1.win.tue.nl/datasets/0004/44063/dataset_44063.pq", sha256="d120af76829af0d256338dc6dd4be5df4fd1f35bf3a283cab66a51c1c6abd06a", @@ -38,7 +39,7 @@ bike_sharing_data_file # %% -# We load the parquet file with Polars for feature engineering, as it +# We load the parquet file with Polars for feature engineering. Polas # automatically caches common subexpressions which are reused in multiple # expressions (like `pl.col("count").shift(1)` below). See # https://docs.pola.rs/user-guide/lazy/optimizations/ for more information. 
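A minimal usage sketch of how the `fetch_file` call in the example cell above behaves across repeated runs, assuming the caching semantics implemented in this series (the URL and SHA256 digest are the ones used in the example; the timing and cache location are left implicit):

# First call downloads the file (with retries on transient network errors);
# the second call finds a local copy whose SHA256 checksum matches and
# returns the same path without touching the network.
from sklearn.datasets import fetch_file

url = "https://openml1.win.tue.nl/datasets/0004/44063/dataset_44063.pq"
sha256 = "d120af76829af0d256338dc6dd4be5df4fd1f35bf3a283cab66a51c1c6abd06a"

first_path = fetch_file(url, sha256=sha256)   # downloads
second_path = fetch_file(url, sha256=sha256)  # served from the local cache
assert first_path == second_path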
From 5aa07b771f708b65f14b1fae74d9c736bc23e7f8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 14:43:32 +0200 Subject: [PATCH 12/29] Simplify _slugify --- sklearn/datasets/_base.py | 14 ++++---------- sklearn/datasets/tests/test_base.py | 3 --- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index f665a8b8c51b9..3ba0d06fb6d68 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1525,7 +1525,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): return file_path -def _slugify(value, allow_unicode=False): +def _slugify(value): """Derive a name that is safe to use as filename from the given string. Adapted from @@ -1539,15 +1539,9 @@ def _slugify(value, allow_unicode=False): Note: this version keeps "." characters unchanged contrary to the django version and replace other un-authorized characters by "_". """ - value = str(value) - if allow_unicode: - value = unicodedata.normalize("NFKC", value) - else: - value = ( - unicodedata.normalize("NFKD", value) - .encode("ascii", "ignore") - .decode("ascii") - ) + value = ( + unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + ) value = re.sub(r"[^.\w\s-]", "_", value.lower()) value = re.sub(r"_+", "_", value) return re.sub(r"[-\s]+", "-", value).strip("-_") diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index a38d4091b4205..b2c943ad4b9c0 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -619,6 +619,3 @@ def test_fetch_file_with_sha256(monkeypatch, tmpdir): folder=client_side, sha256=non_matching_sha256, ) - # The local file should not have been deleted. - assert client_side.join("data.jsonl").read_text(encoding="utf-8") == server_data - assert urlretrieve_mock.call_count == 3 From 5da0feb39f87ed6f208697cda3824f82f6a7bf6b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 17:26:22 +0200 Subject: [PATCH 13/29] Apply suggestions from code review Co-authored-by: Guillaume Lemaitre --- examples/applications/plot_time_series_lagged_features.py | 4 ++-- sklearn/datasets/_base.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index 70823f74f17bf..d7d9e23113a56 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -21,7 +21,7 @@ # # We start by loading the data from the OpenML repository as a raw parquet file # to illustrate how to work with an arbitrary parquet file instead of hiding this -# step in a convenience too such as `sklearn.datasets.fetch_openml`. +# step in a convenience tool such as `sklearn.datasets.fetch_openml`. # # The URL of the parquet file can be found in the JSON description of the # Bike Sharing Demand v7 dataset on openml.org. @@ -39,7 +39,7 @@ bike_sharing_data_file # %% -# We load the parquet file with Polars for feature engineering. Polas +# We load the parquet file with Polars for feature engineering. Polars # automatically caches common subexpressions which are reused in multiple # expressions (like `pl.col("count").shift(1)` below). See # https://docs.pola.rs/user-guide/lazy/optimizations/ for more information. 
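For context, the lagged features referred to by the `pl.col("count").shift(1)` comment above can be built with Polars expressions along these lines (a toy frame stands in for the bike sharing data, and the `lagged_count_*` column names are illustrative, not taken from this diff):

import polars as pl

# Toy stand-in for the hourly "count" series of the bike sharing dataset.
df = pl.DataFrame({"count": [3, 5, 8, 13, 21, 34]}).lazy()

# Each shifted copy of the series becomes one lagged feature column.
lagged = df.select(
    pl.col("count"),
    pl.col("count").shift(1).alias("lagged_count_1h"),
    pl.col("count").shift(2).alias("lagged_count_2h"),
).collect()
print(lagged)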
diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 3ba0d06fb6d68..90ed1803f2f31 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1572,7 +1572,7 @@ def _derive_folder_and_filename_from_url(url): def fetch_file( url, folder=None, local_filename=None, sha256=None, n_retries=3, delay=1 ): - """Fetch a file from the web. + """Fetch a file from the web if not already present in the local folder. If the file already exists locally (and the SHA256 checksums match when provided), the path to the local file is returned without re-downloading. From b9df5b227799e4ba0275a894ebf256b636d626f6 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 17:28:27 +0200 Subject: [PATCH 14/29] Add URL to openml UI for the Bike Sharing Dataset --- examples/applications/plot_time_series_lagged_features.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index d7d9e23113a56..abbed9a54cc36 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -24,7 +24,8 @@ # step in a convenience tool such as `sklearn.datasets.fetch_openml`. # # The URL of the parquet file can be found in the JSON description of the -# Bike Sharing Demand v7 dataset on openml.org. +# Bike Sharing Demand v7 dataset on openml.org +# (https://openml.org/search?type=data&status=active&id=44063). import numpy as np import polars as pl From 4e50efa89897667eb53ad2264c17d79a0f68a678 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 17:30:08 +0200 Subject: [PATCH 15/29] Explain the use of the sha256 argument --- examples/applications/plot_time_series_lagged_features.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index abbed9a54cc36..bbd53de00db6e 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -24,8 +24,11 @@ # step in a convenience tool such as `sklearn.datasets.fetch_openml`. # # The URL of the parquet file can be found in the JSON description of the -# Bike Sharing Demand v7 dataset on openml.org +# Bike Sharing Demand dataset with id 44063 on openml.org # (https://openml.org/search?type=data&status=active&id=44063). +# +# The `sha256` hash of the file is also provided to ensure the integrity of the +# downloaded file. import numpy as np import polars as pl From fa429e1a689a9a762a9726311e5a86b03d807462 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 17:31:04 +0200 Subject: [PATCH 16/29] Trim useless empty cell. Co-authored-by: Guillaume Lemaitre --- examples/applications/plot_time_series_lagged_features.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index bbd53de00db6e..fae1b7705a815 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -434,4 +434,3 @@ def min_arg(col): # can be used to extend scikit-learn estimators by making use of recursive time # series forecasting, that enables dynamic predictions of future values. 
-# %% From 811dc3cfb08d167910bf5be8542816a5ee03e086 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Jun 2024 17:44:14 +0200 Subject: [PATCH 17/29] Trailing line --- examples/applications/plot_time_series_lagged_features.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/applications/plot_time_series_lagged_features.py b/examples/applications/plot_time_series_lagged_features.py index fae1b7705a815..2efc12acae276 100644 --- a/examples/applications/plot_time_series_lagged_features.py +++ b/examples/applications/plot_time_series_lagged_features.py @@ -433,4 +433,3 @@ def min_arg(col): # by `sktime `_ # can be used to extend scikit-learn estimators by making use of recursive time # series forecasting, that enables dynamic predictions of future values. - From 91812b5066c25eb12f518efc35b694dcf529765e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 10:35:48 +0200 Subject: [PATCH 18/29] Update the docstring of _slugify to better describe the actual behavior --- sklearn/datasets/_base.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 90ed1803f2f31..5f1e99764dbb0 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1531,13 +1531,10 @@ def _slugify(value): Adapted from https://github.com/django/django/blob/master/django/utils/text.py - Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated - dashes to single dashes. Remove characters that aren't alphanumerics, - underscores, or hyphens. Convert to lowercase. Also strip leading and + Convert to ASCII, convert spaces or repeated dashes to single dashes. + Replace characters that aren't alphanumerics, underscores, hyphens or + periods by underscores. Convert to lowercase. Also strip leading and trailing whitespace, dashes, and underscores. - - Note: this version keeps "." characters unchanged contrary to the django - version and replace other un-authorized characters by "_". """ value = ( unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") From f465998cd6bf00b40abf722d8fe7d90a7ee2cf68 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 11:52:46 +0200 Subject: [PATCH 19/29] Improve _derive_folder_and_filename_from_url safety based on feedback from review --- sklearn/datasets/_base.py | 45 +++++++++++++++-------------- sklearn/datasets/tests/test_base.py | 31 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 5f1e99764dbb0..10bf7330f9ba2 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1525,45 +1525,48 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): return file_path -def _slugify(value): +def _filter_filename(value, filter_dots=True): """Derive a name that is safe to use as filename from the given string. Adapted from https://github.com/django/django/blob/master/django/utils/text.py - Convert to ASCII, convert spaces or repeated dashes to single dashes. - Replace characters that aren't alphanumerics, underscores, hyphens or - periods by underscores. Convert to lowercase. Also strip leading and - trailing whitespace, dashes, and underscores. + Convert spaces or repeated dashes to single dashes. Replace characters that + aren't alphanumerics, underscores, hyphens or dots by underscores. Convert + to lowercase. Also strip leading and trailing whitespace, dashes, and + underscores. 
""" - value = ( - unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") - ) - value = re.sub(r"[^.\w\s-]", "_", value.lower()) + value = unicodedata.normalize("NFKD", value).lower() + if filter_dots: + value = re.sub(r"[^\w\s-]", "_", value) + else: + value = re.sub(r"[^.\w\s-]", "_", value) value = re.sub(r"_+", "_", value) - return re.sub(r"[-\s]+", "-", value).strip("-_") + value = re.sub(r"-+", "-", value) + return value.strip("-_.") def _derive_folder_and_filename_from_url(url): parsed_url = urlparse(url) + if not parsed_url.hostname: + raise ValueError(f"Invalid URL: {url}") + folder_components = [_filter_filename(parsed_url.hostname, filter_dots=False)] path = parsed_url.path - if not path: - path = "/" if "/" in path: - base_folder, filename = path.rsplit("/", 1) + base_folder, raw_filename = path.rsplit("/", 1) + + base_folder = _filter_filename(base_folder) + if base_folder: + folder_components.append(base_folder) + else: + raw_filename = path + filename = _filter_filename(raw_filename, filter_dots=False) if not filename: filename = "downloaded_file" - base_folder = _slugify(base_folder) - if base_folder: - base_folder = "/" + base_folder - - return ( - _slugify(parsed_url.hostname) + base_folder, - _slugify(filename), - ) + return "/".join(folder_components), filename def fetch_file( diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index b2c943ad4b9c0..7f656e6131e80 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -405,6 +405,12 @@ def test_derive_folder_and_filename_from_url(): assert folder == "example.com" assert filename == "file.tar.gz" + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/نمونه نماینده.data" + ) + assert folder == "example.com" + assert filename == "نمونه نماینده.data" + folder, filename = _derive_folder_and_filename_from_url( "https://example.com/path/to/file.tar.gz" ) @@ -431,6 +437,31 @@ def test_derive_folder_and_filename_from_url(): assert folder == "example.com/path_to" assert filename == "data.json" + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com//some_file.txt" + ) + assert folder == "example.com" + assert filename == "some_file.txt" + + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/!.'.,/some_file.txt" + ) + assert folder == "example.com" + assert filename == "some_file.txt" + + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/a/!.'.,/b/some_file.txt" + ) + assert folder == "example.com/a_b" + assert filename == "some_file.txt" + + folder, filename = _derive_folder_and_filename_from_url("https://example.com/!.'.,") + assert folder == "example.com" + assert filename == "downloaded_file" + + with pytest.raises(ValueError, match="Invalid URL"): + _derive_folder_and_filename_from_url("https:/../") + def _mock_urlretrieve(server_side): def _urlretrieve_mock(url, local_path): From 8c2120aadb0b4cd8309168443be31624140841fe Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 14:37:37 +0200 Subject: [PATCH 20/29] Simpler handling of repetition + fix bug in replacing white spaces as dashes --- sklearn/datasets/_base.py | 7 +++---- sklearn/datasets/tests/test_base.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 10bf7330f9ba2..38f441af315bd 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1538,11 
+1538,10 @@ def _filter_filename(value, filter_dots=True): """ value = unicodedata.normalize("NFKD", value).lower() if filter_dots: - value = re.sub(r"[^\w\s-]", "_", value) + value = re.sub(r"[^\w\s-]+", "_", value) else: - value = re.sub(r"[^.\w\s-]", "_", value) - value = re.sub(r"_+", "_", value) - value = re.sub(r"-+", "-", value) + value = re.sub(r"[^.\w\s-]+", "_", value) + value = re.sub(r"[\s-]+", "-", value) return value.strip("-_.") diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 7f656e6131e80..d56d387f9a72a 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -409,7 +409,7 @@ def test_derive_folder_and_filename_from_url(): "https://example.com/نمونه نماینده.data" ) assert folder == "example.com" - assert filename == "نمونه نماینده.data" + assert filename == "نمونه-نماینده.data" folder, filename = _derive_folder_and_filename_from_url( "https://example.com/path/to/file.tar.gz" @@ -432,7 +432,7 @@ def test_derive_folder_and_filename_from_url(): assert filename == "data.json" folder, filename = _derive_folder_and_filename_from_url( - "https://example.com/path/@to/data.json#anchor" + "https://example.com/path/@@to/data.json#anchor" ) assert folder == "example.com/path_to" assert filename == "data.json" From f14235df583846f49b6837e6f17d894d5998b648 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 14:44:27 +0200 Subject: [PATCH 21/29] Empty commit to trigger PR update From 4854ebadb21e9aad477ad22cbfd81ff83ec68616 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 15:48:50 +0200 Subject: [PATCH 22/29] Test .. explicitly and more stripping patterns --- sklearn/datasets/tests/test_base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index d56d387f9a72a..85dfee76f22c6 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -412,7 +412,7 @@ def test_derive_folder_and_filename_from_url(): assert filename == "نمونه-نماینده.data" folder, filename = _derive_folder_and_filename_from_url( - "https://example.com/path/to/file.tar.gz" + "https://example.com/path/to/.file.tar.gz" ) assert folder == "example.com/path_to" assert filename == "file.tar.gz" @@ -432,7 +432,7 @@ def test_derive_folder_and_filename_from_url(): assert filename == "data.json" folder, filename = _derive_folder_and_filename_from_url( - "https://example.com/path/@@to/data.json#anchor" + "https://example.com/path/@@to._/-_.data.json.#anchor" ) assert folder == "example.com/path_to" assert filename == "data.json" @@ -443,6 +443,12 @@ def test_derive_folder_and_filename_from_url(): assert folder == "example.com" assert filename == "some_file.txt" + folder, filename = _derive_folder_and_filename_from_url( + "https://example.com/../some_file.txt" + ) + assert folder == "example.com" + assert filename == "some_file.txt" + folder, filename = _derive_folder_and_filename_from_url( "https://example.com/!.'.,/some_file.txt" ) From c78019f4e9c667b85f91da8090dbca49be54af80 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jul 2024 15:52:12 +0200 Subject: [PATCH 23/29] Better test what Adrin actually suggested --- sklearn/datasets/tests/test_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 85dfee76f22c6..8c5f3618a1d58 100644 --- 
a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -412,7 +412,7 @@ def test_derive_folder_and_filename_from_url(): assert filename == "نمونه-نماینده.data" folder, filename = _derive_folder_and_filename_from_url( - "https://example.com/path/to/.file.tar.gz" + "https://example.com/path/to-/.file.tar.gz" ) assert folder == "example.com/path_to" assert filename == "file.tar.gz" @@ -444,9 +444,9 @@ def test_derive_folder_and_filename_from_url(): assert filename == "some_file.txt" folder, filename = _derive_folder_and_filename_from_url( - "https://example.com/../some_file.txt" + "http://example/../some_file.txt" ) - assert folder == "example.com" + assert folder == "example" assert filename == "some_file.txt" folder, filename = _derive_folder_and_filename_from_url( From f55da8cc8802a77060d1625f9b84e6a70e26a4ea Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 9 Jul 2024 14:59:23 +0200 Subject: [PATCH 24/29] Better corrupted contents. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Loïc Estève --- sklearn/datasets/tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 8c5f3618a1d58..8b5231f68abdd 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -602,7 +602,7 @@ def test_fetch_file_with_sha256(monkeypatch, tmpdir): assert urlretrieve_mock.call_count == 1 # Corrupting the local data should yield a warning and trigger a new download: - fetched_file_path.write_text("corruped contents", encoding="utf-8") + fetched_file_path.write_text("corrupted contents", encoding="utf-8") expected_msg = ( r"SHA256 checksum of existing local file data.jsonl " rf"\(.*\) differs from expected \({expected_sha256}\): " From e6e35a192a592e9cb1bd23bbad28ac448cedbb7d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 9 Jul 2024 17:51:39 +0200 Subject: [PATCH 25/29] Mention slugify in the docstring. --- sklearn/datasets/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 38f441af315bd..871fe5ce49e4b 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1528,7 +1528,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): def _filter_filename(value, filter_dots=True): """Derive a name that is safe to use as filename from the given string. - Adapted from + Adapted from the `slugify` function of django: https://github.com/django/django/blob/master/django/utils/text.py Convert spaces or repeated dashes to single dashes. Replace characters that From 94e556395c102ea634e88ef8396bd835d8272827 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Jul 2024 15:42:15 +0200 Subject: [PATCH 26/29] Explain the logic behid manual deletion and renaming of the temporary file --- sklearn/datasets/_base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 871fe5ce49e4b..fb71daf029c09 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1488,6 +1488,14 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): f"re-downloading from {remote.url} ." ) + # We create a temporary file dedicated to this particular download to avoid + # conflicts with parallel downloads. 
If the download is successful, the + # temporary file is atomically renamed to the final file path (with + # `shutil.move`). We therefore pass `delete=False` to `NamedTemporaryFile`. + # Otherwise, garbage collecting temp_file would raise an error when + # attempting to delete a file that was already renamed. If the download + # fails or the result does not match the expected SHA256 digest, the + # temporary file is removed manually in the except block. temp_file = NamedTemporaryFile( prefix=remote.filename + ".part_", dir=folder_path, delete=False ) @@ -1512,7 +1520,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): f"The SHA256 checksum of {remote.filename} ({checksum}) " f"differs from expected ({remote.checksum})." ) - except BaseException: + except Exception: os.unlink(temp_file.name) raise From 82211b364bf8d8f3e2b827f63482a4ae95c318e3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Jul 2024 15:47:35 +0200 Subject: [PATCH 27/29] Also clean the tempfile in case of ctrl-c --- sklearn/datasets/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index fb71daf029c09..62fa8210dab2b 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1520,7 +1520,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): f"The SHA256 checksum of {remote.filename} ({checksum}) " f"differs from expected ({remote.checksum})." ) - except Exception: + except (Exception, KeyboardInterrupt): os.unlink(temp_file.name) raise From bed85156cdec5b257ed03771fce7344d6d194073 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Jul 2024 14:12:54 +0200 Subject: [PATCH 28/29] Clarify the relationship between delete=False and temp_file.close() --- sklearn/datasets/_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 62fa8210dab2b..5e888c0c0f548 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1499,6 +1499,11 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): temp_file = NamedTemporaryFile( prefix=remote.filename + ".part_", dir=folder_path, delete=False ) + # Note that Python 3.12's `delete_on_close=True` is ignored as we set + # `delete=False` explicitly. So after this line the empty temporary file still + # exists on disk to make sure that it's uniquely reserved for this specific call of + # `_fetch_remote` and therefore it protects against any corruption by parallel + # calls. temp_file.close() try: temp_file_path = Path(temp_file.name) From 4fc5c10ec9d4ead05b28f1c4b3a8375d1781e601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 11 Jul 2024 14:40:54 +0200 Subject: [PATCH 29/29] lint --- sklearn/datasets/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index 5e888c0c0f548..62055d296402b 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -1499,7 +1499,7 @@ def _fetch_remote(remote, dirname=None, n_retries=3, delay=1): temp_file = NamedTemporaryFile( prefix=remote.filename + ".part_", dir=folder_path, delete=False ) - # Note that Python 3.12's `delete_on_close=True` is ignored as we set + # Note that Python 3.12's `delete_on_close=True` is ignored as we set # `delete=False` explicitly. 
So after this line the empty temporary file still # exists on disk to make sure that it's uniquely reserved for this specific call of # `_fetch_remote` and therefore it protects against any corruption by parallel
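Taken together, the last few patches converge on a single download pattern: reserve a uniquely named temporary file next to the target, download into it, verify the checksum, and only then move it into place. The sketch below condenses that pattern outside of scikit-learn; the helper names `download_to` and `_sha256_of` are illustrative rather than the actual private API, and retries are omitted for brevity.

import hashlib
import os
import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile
from urllib.request import urlretrieve


def _sha256_of(path):
    # Stream the file in chunks so large downloads need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_to(url, file_path, expected_sha256=None):
    file_path = Path(file_path)
    # Reserve a uniquely named temporary file next to the target so that
    # concurrent downloads of the same file cannot corrupt each other.
    temp_file = NamedTemporaryFile(
        prefix=file_path.name + ".part_", dir=file_path.parent, delete=False
    )
    # Closing keeps the (empty) file on disk because delete=False.
    temp_file.close()
    try:
        urlretrieve(url, temp_file.name)
        checksum = _sha256_of(temp_file.name)
        if expected_sha256 is not None and checksum != expected_sha256:
            raise OSError("SHA256 checksum mismatch, file may be corrupted.")
    except (Exception, KeyboardInterrupt):
        # Remove the partial download on any failure, including Ctrl-C.
        os.unlink(temp_file.name)
        raise
    # Atomic when source and destination share a filesystem; shutil.move
    # still works (non-atomically) when they do not.
    shutil.move(temp_file.name, file_path)
    return file_path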