Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Use OS-specific cache directories for get_data_home and add tests #31438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
Loading
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 2 benchmarks/bench_tsne_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

import numpy as np
from joblib import Memory
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array
from sklearn.utils import shuffle as _shuffle
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

LOG_DIR = "mnist_tsne_output"
if not os.path.exists(LOG_DIR):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- New parameter ``return_X_y`` added to :func:`datasets.make_classification`. The
default value of the parameter does not change how the function behaves.
By :user:`Success Moses <SuccessMoses>` and :user:`Adam Cooper <arc12>`

3 changes: 3 additions & 0 deletions 3 doc/whats_new/upcoming_changes/sklearn.datasets/31267.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Updated 'get_data_home' to use OS-specific cache directories instead of home dir by
PR: 31267
By :user:Namit24
66 changes: 31 additions & 35 deletions 66 sklearn/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import gzip
import hashlib
import os
import platform
import re
import shutil
import time
Expand All @@ -17,8 +18,8 @@
from collections import namedtuple
from importlib import resources
from numbers import Integral
from os import environ, listdir, makedirs
from os.path import expanduser, isdir, join, splitext
from os import listdir, makedirs
from os.path import isdir, join, splitext
from pathlib import Path
from tempfile import NamedTemporaryFile
from urllib.error import URLError
Expand All @@ -45,44 +46,39 @@
},
prefer_skip_nested_validation=True,
)
def get_data_home(data_home=None) -> str:
"""Return the path of the scikit-learn data directory.
def get_data_home(data_home=None):
"""Return the path to scikit-learn data home cache folder.

This folder is used by some large dataset loaders to avoid downloading the
data several times.
By default, it uses OS-appropriate cache directories:
- Linux: $XDG_CACHE_HOME/scikit_learn_data or ~/.cache/scikit_learn_data
- macOS: ~/Library/Caches/scikit_learn_data
- Windows: %LOCALAPPDATA%/scikit_learn_data
"""
if data_home is None:
# Determine the base cache directory based on the operating system
system = platform.system()
if system == "Linux":
base_dir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
elif system == "Darwin": # macOS
base_dir = os.path.expanduser("~/Library/Caches")
elif system == "Windows":
base_dir = os.environ.get(
"LOCALAPPDATA", os.path.expanduser("~/AppData/Local")
)
else:
# Fallback for other operating systems
base_dir = os.path.expanduser("~/.cache")

By default the data directory is set to a folder named 'scikit_learn_data' in the
user home folder.
data_home = os.path.join(base_dir, "scikit_learn_data")

Alternatively, it can be set by the 'SCIKIT_LEARN_DATA' environment
variable or programmatically by giving an explicit folder path. The '~'
symbol is expanded to the user home folder.
# Override with environment variable if set
data_home = os.environ.get("SCIKIT_LEARN_DATA", data_home)

If the folder does not already exist, it is automatically created.
# Expand user path (handles ~ and environment variables)
data_home = os.path.expanduser(data_home)

Parameters
----------
data_home : str or path-like, default=None
The path to scikit-learn data directory. If `None`, the default path
is `~/scikit_learn_data`.

Returns
-------
data_home: str
The path to scikit-learn data directory.

Examples
--------
>>> import os
>>> from sklearn.datasets import get_data_home
>>> data_home_path = get_data_home()
>>> os.path.exists(data_home_path)
True
"""
if data_home is None:
data_home = environ.get("SCIKIT_LEARN_DATA", join("~", "scikit_learn_data"))
data_home = expanduser(data_home)
makedirs(data_home, exist_ok=True)
# Create the directory if it doesn't exist
os.makedirs(data_home, exist_ok=True)
return data_home


Expand Down
30 changes: 30 additions & 0 deletions 30 sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import os
import pickle
import re
import warnings
from pathlib import Path

import numpy as np
import pytest
Expand All @@ -23,6 +25,7 @@
is_regressor,
)
from sklearn.cluster import KMeans
from sklearn.datasets import get_data_home
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.exceptions import InconsistentVersionWarning
Expand Down Expand Up @@ -1000,3 +1003,30 @@ def test_get_params_html():

assert est._get_params_html() == {"l1": 0, "empty": "test"}
assert est._get_params_html().non_default == ("empty",)


def test_get_data_home_platforms(monkeypatch, tmp_path):
"""Test platform-specific cache directories."""

# Test Linux with XDG_CACHE_HOME
monkeypatch.setattr("platform.system", lambda: "Linux")
monkeypatch.setenv("XDG_CACHE_HOME", "/tmp/xdg_cache")
expected = Path("/tmp/xdg_cache/scikit_learn_data")
assert Path(get_data_home()) == expected

# Test Linux without XDG_CACHE_HOME
monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
expected = Path(os.path.expanduser("~/.cache/scikit_learn_data"))
assert Path(get_data_home()) == expected

# Test macOS
monkeypatch.setattr("platform.system", lambda: "Darwin")
expected = Path(os.path.expanduser("~/Library/Caches/scikit_learn_data"))
assert Path(get_data_home()) == expected

# Test Windows
monkeypatch.setattr("platform.system", lambda: "Windows")
monkeypatch.setenv("LOCALAPPDATA", str(tmp_path))
expected = tmp_path / "scikit_learn_data"
result = Path(get_data_home())
assert result == expected
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.