195 lines
6.0 KiB
Python
195 lines
6.0 KiB
Python
|
import os
|
||
|
import warnings
|
||
|
from os import environ
|
||
|
from os.path import exists, join
|
||
|
|
||
|
import pytest
|
||
|
from _pytest.doctest import DoctestItem
|
||
|
|
||
|
from sklearn.datasets import get_data_home
|
||
|
from sklearn.datasets._base import _pkl_filepath
|
||
|
from sklearn.datasets._twenty_newsgroups import CACHE_NAME
|
||
|
from sklearn.utils._testing import SkipTest, check_skip_network
|
||
|
from sklearn.utils.fixes import _IS_PYPY, np_base_version, parse_version
|
||
|
|
||
|
|
||
|
def setup_labeled_faces():
    """Skip the labeled-faces doctests unless the data is already on disk.

    Raises
    ------
    SkipTest
        If there is no ``lfw_home`` entry inside the directory returned by
        ``get_data_home()``.
    """
    lfw_dir = join(get_data_home(), "lfw_home")
    if exists(lfw_dir):
        return
    raise SkipTest("Skipping dataset loading doctests")
|
||
|
|
||
|
|
||
|
def setup_rcv1():
    """Skip the rcv1.rst doctests unless the RCV1 data is already on disk.

    ``check_skip_network()`` is called first; presumably it skips the test
    when network tests are disabled (verify against its definition).

    Raises
    ------
    SkipTest
        If there is no ``RCV1`` entry inside the directory returned by
        ``get_data_home()``.
    """
    check_skip_network()
    # The doctests only run against a dataset that was downloaded beforehand.
    if exists(join(get_data_home(), "RCV1")):
        return
    raise SkipTest("Download RCV1 dataset to run this test.")
|
||
|
|
||
|
|
||
|
def setup_twenty_newsgroups():
    """Skip the twenty-newsgroups doctests unless the cache file exists.

    Raises
    ------
    SkipTest
        If the path built by ``_pkl_filepath(get_data_home(), CACHE_NAME)``
        does not exist.
    """
    newsgroups_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(newsgroups_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")
|
||
|
|
||
|
|
||
|
def setup_working_with_text_data():
    """Decide whether the "working with text data" doctests can run.

    The checks are applied in this order: PyPy on CI, the network check
    performed by ``check_skip_network()``, then the presence of the
    twenty-newsgroups cache file.

    Raises
    ------
    SkipTest
        If running under PyPy with a non-empty ``CI`` environment variable,
        or if the twenty-newsgroups cache file does not exist.
    """
    running_on_ci = os.environ.get("CI")
    if _IS_PYPY and running_on_ci:
        raise SkipTest("Skipping too slow test with PyPy on CI")

    check_skip_network()

    newsgroups_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(newsgroups_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")
|
||
|
|
||
|
|
||
|
def setup_loading_other_datasets():
    """Skip loading_other_datasets.rst unless pandas and network tests are on.

    Raises
    ------
    SkipTest
        If pandas cannot be imported, or if the environment variable
        ``SKLEARN_SKIP_NETWORK_TESTS`` is anything other than ``"0"``
        (it is treated as ``"1"`` when unset).
    """
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping loading_other_datasets.rst, pandas not installed")

    # Network tests are opt-in here: only an explicit "0" enables them.
    skip_flag = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1")
    if skip_flag != "0":
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )
|
||
|
|
||
|
|
||
|
def setup_compose():
    """Skip the compose.rst doctests when pandas is unavailable.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    try:
        # Only the import itself matters; the module is not used afterwards.
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping compose.rst, pandas not installed")
|
||
|
|
||
|
|
||
|
def setup_impute():
    """Skip the impute.rst doctests when pandas is unavailable.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    try:
        # Only the import itself matters; the module is not used afterwards.
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping impute.rst, pandas not installed")
|
||
|
|
||
|
|
||
|
def setup_grid_search():
    """Skip the grid_search.rst doctests when pandas is unavailable.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    try:
        # Only the import itself matters; the module is not used afterwards.
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping grid_search.rst, pandas not installed")
|
||
|
|
||
|
|
||
|
def setup_preprocessing():
    """Skip the preprocessing.rst doctests without a recent enough pandas.

    Raises
    ------
    SkipTest
        If pandas cannot be imported, or if its version is lower than 1.1.0.
    """
    try:
        import pandas  # noqa

        installed_version = parse_version(pandas.__version__)
        minimum_version = parse_version("1.1.0")
        if installed_version < minimum_version:
            raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")
    except ImportError:
        raise SkipTest("Skipping preprocessing.rst, pandas not installed")
|
||
|
|
||
|
|
||
|
def setup_unsupervised_learning():
    """Prepare the unsupervised_learning.rst doctests.

    Skips them when scikit-image is unavailable; otherwise installs a
    warnings filter for a known deprecation message.

    Raises
    ------
    SkipTest
        If importing ``skimage`` fails with ``ImportError``.
    """
    try:
        import skimage  # noqa
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")

    # ignore deprecation warnings from scipy.misc.face
    warnings.filterwarnings(
        action="ignore",
        message="The binary mode of fromstring",
        category=DeprecationWarning,
    )
|
||
|
|
||
|
|
||
|
def skip_if_matplotlib_not_installed(fname):
    """Skip the doctests of ``fname`` when matplotlib is unavailable.

    Parameters
    ----------
    fname : str
        Path of the file being tested; only its base name appears in the
        skip message.

    Raises
    ------
    SkipTest
        If importing matplotlib fails with ``ImportError``.
    """
    try:
        import matplotlib  # noqa
    except ImportError:
        short_name = os.path.basename(fname)
        raise SkipTest(f"Skipping doctests for {short_name}, matplotlib not installed")
|
||
|
|
||
|
|
||
|
def skip_if_cupy_not_installed(fname):
    """Skip the doctests of ``fname`` when cupy is unavailable.

    Parameters
    ----------
    fname : str
        Path of the file being tested; only its base name appears in the
        skip message.

    Raises
    ------
    SkipTest
        If importing cupy fails with ``ImportError``.
    """
    try:
        import cupy  # noqa
    except ImportError:
        short_name = os.path.basename(fname)
        raise SkipTest(f"Skipping doctests for {short_name}, cupy not installed")
|
||
|
|
||
|
|
||
|
def pytest_runtest_setup(item):
    """Run the file-specific setup / skip checks before a test item executes.

    The item's file path is matched by suffix against a list of known
    documentation files. At most one of the ``setup_*`` functions is called
    (the first match wins); the matplotlib and cupy checks are applied
    independently afterwards.

    Parameters
    ----------
    item : pytest item
        The item about to run; its file path is read from
        ``item.fspath.strpath``.
    """
    # Normalize to forward slashes so the suffix checks also work on Windows.
    fname = item.fspath.strpath.replace(os.sep, "/")

    # Ordered (suffix or suffixes, setup function) pairs; first match wins.
    # NOTE(review): datasets/index.rst only ever triggers the labeled-faces
    # check. Confirm whether the rcv1 / twenty_newsgroups / compose checks
    # were also meant to apply to the index page.
    setup_by_suffix = (
        (
            ("datasets/labeled_faces.rst", "datasets/index.rst"),
            setup_labeled_faces,
        ),
        ("datasets/rcv1.rst", setup_rcv1),
        ("datasets/twenty_newsgroups.rst", setup_twenty_newsgroups),
        ("modules/compose.rst", setup_compose),
        ("datasets/loading_other_datasets.rst", setup_loading_other_datasets),
        ("modules/impute.rst", setup_impute),
        ("modules/grid_search.rst", setup_grid_search),
        ("modules/preprocessing.rst", setup_preprocessing),
        (
            "statistical_inference/unsupervised_learning.rst",
            setup_unsupervised_learning,
        ),
    )
    for suffix, run_setup in setup_by_suffix:
        if fname.endswith(suffix):
            run_setup()
            break

    # These suffixes are mutually exclusive, so one tuple check is enough.
    rst_files_requiring_matplotlib = (
        "modules/partial_dependence.rst",
        "modules/tree.rst",
    )
    if fname.endswith(rst_files_requiring_matplotlib):
        skip_if_matplotlib_not_installed(fname)

    if fname.endswith("array_api.rst"):
        skip_if_cupy_not_installed(fname)
|
||
|
|
||
|
|
||
|
def pytest_configure(config):
    """Select the matplotlib "agg" backend for the test session, if possible.

    Parameters
    ----------
    config : pytest config
        Unused.
    """
    # Use matplotlib agg backend during the tests including doctests
    try:
        import matplotlib as mpl

        mpl.use("agg")
    except ImportError:
        # Nothing to configure when the import fails.
        return
|
||
|
|
||
|
|
||
|
def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Resets the globals of every collected doctest item and, when the numpy
    base version is 2 or later, marks those items as skipped.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    # Skip doctests when using numpy 2 for now. See the following discussion
    # to decide what to do in the longer term:
    # https://github.com/scikit-learn/scikit-learn/issues/27339
    skip_doctests = np_base_version >= parse_version("2")

    doctest_items = [entry for entry in items if isinstance(entry, DoctestItem)]

    # Normally doctest has the entire module's scope. Here we set globs to an empty dict
    # to remove the module's scope:
    # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
    for entry in doctest_items:
        entry.dtest.globs = {}

    if not skip_doctests:
        return

    # A single marker instance is shared by all skipped doctest items.
    skip_marker = pytest.mark.skip(
        reason="Due to NEP 51 numpy scalar repr has changed in numpy 2"
    )
    for entry in doctest_items:
        entry.add_marker(skip_marker)
|