init
This commit is contained in:
@@ -0,0 +1,639 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
import operator
|
||||
import os
|
||||
from sys import byteorder
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
ContextManager,
|
||||
cast,
|
||||
)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas._config.localization import (
|
||||
can_set_locale,
|
||||
get_locales,
|
||||
set_locale,
|
||||
)
|
||||
|
||||
from pandas.compat import pa_version_under10p1
|
||||
|
||||
from pandas.core.dtypes.common import is_string_dtype
|
||||
|
||||
import pandas as pd
|
||||
from pandas import (
|
||||
ArrowDtype,
|
||||
DataFrame,
|
||||
Index,
|
||||
MultiIndex,
|
||||
RangeIndex,
|
||||
Series,
|
||||
)
|
||||
from pandas._testing._io import (
|
||||
round_trip_localpath,
|
||||
round_trip_pathlib,
|
||||
round_trip_pickle,
|
||||
write_to_compressed,
|
||||
)
|
||||
from pandas._testing._warnings import (
|
||||
assert_produces_warning,
|
||||
maybe_produces_warning,
|
||||
)
|
||||
from pandas._testing.asserters import (
|
||||
assert_almost_equal,
|
||||
assert_attr_equal,
|
||||
assert_categorical_equal,
|
||||
assert_class_equal,
|
||||
assert_contains_all,
|
||||
assert_copy,
|
||||
assert_datetime_array_equal,
|
||||
assert_dict_equal,
|
||||
assert_equal,
|
||||
assert_extension_array_equal,
|
||||
assert_frame_equal,
|
||||
assert_index_equal,
|
||||
assert_indexing_slices_equivalent,
|
||||
assert_interval_array_equal,
|
||||
assert_is_sorted,
|
||||
assert_is_valid_plot_return_object,
|
||||
assert_metadata_equivalent,
|
||||
assert_numpy_array_equal,
|
||||
assert_period_array_equal,
|
||||
assert_series_equal,
|
||||
assert_sp_array_equal,
|
||||
assert_timedelta_array_equal,
|
||||
raise_assert_detail,
|
||||
)
|
||||
from pandas._testing.compat import (
|
||||
get_dtype,
|
||||
get_obj,
|
||||
)
|
||||
from pandas._testing.contexts import (
|
||||
assert_cow_warning,
|
||||
decompress_file,
|
||||
ensure_clean,
|
||||
raises_chained_assignment_error,
|
||||
set_timezone,
|
||||
use_numexpr,
|
||||
with_csv_dialect,
|
||||
)
|
||||
from pandas.core.arrays import (
|
||||
BaseMaskedArray,
|
||||
ExtensionArray,
|
||||
NumpyExtensionArray,
|
||||
)
|
||||
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
|
||||
from pandas.core.construction import extract_array
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas._typing import (
|
||||
Dtype,
|
||||
NpDtype,
|
||||
)
|
||||
|
||||
from pandas.core.arrays import ArrowExtensionArray
|
||||
|
||||
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
|
||||
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
|
||||
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
|
||||
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
|
||||
ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
|
||||
ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
|
||||
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]
|
||||
|
||||
FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
|
||||
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
|
||||
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]
|
||||
|
||||
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
|
||||
STRING_DTYPES: list[Dtype] = [str, "str", "U"]
|
||||
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]
|
||||
|
||||
DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
|
||||
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]
|
||||
|
||||
BOOL_DTYPES: list[Dtype] = [bool, "bool"]
|
||||
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
|
||||
OBJECT_DTYPES: list[Dtype] = [object, "object"]
|
||||
|
||||
ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
|
||||
ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
|
||||
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
|
||||
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]
|
||||
|
||||
ALL_NUMPY_DTYPES = (
|
||||
ALL_REAL_NUMPY_DTYPES
|
||||
+ COMPLEX_DTYPES
|
||||
+ STRING_DTYPES
|
||||
+ DATETIME64_DTYPES
|
||||
+ TIMEDELTA64_DTYPES
|
||||
+ BOOL_DTYPES
|
||||
+ OBJECT_DTYPES
|
||||
+ BYTES_DTYPES
|
||||
)
|
||||
|
||||
NARROW_NP_DTYPES = [
|
||||
np.float16,
|
||||
np.float32,
|
||||
np.int8,
|
||||
np.int16,
|
||||
np.int32,
|
||||
np.uint8,
|
||||
np.uint16,
|
||||
np.uint32,
|
||||
]
|
||||
|
||||
PYTHON_DATA_TYPES = [
|
||||
str,
|
||||
int,
|
||||
float,
|
||||
complex,
|
||||
list,
|
||||
tuple,
|
||||
range,
|
||||
dict,
|
||||
set,
|
||||
frozenset,
|
||||
bool,
|
||||
bytes,
|
||||
bytearray,
|
||||
memoryview,
|
||||
]
|
||||
|
||||
ENDIAN = {"little": "<", "big": ">"}[byteorder]
|
||||
|
||||
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
|
||||
NP_NAT_OBJECTS = [
|
||||
cls("NaT", unit)
|
||||
for cls in [np.datetime64, np.timedelta64]
|
||||
for unit in [
|
||||
"Y",
|
||||
"M",
|
||||
"W",
|
||||
"D",
|
||||
"h",
|
||||
"m",
|
||||
"s",
|
||||
"ms",
|
||||
"us",
|
||||
"ns",
|
||||
"ps",
|
||||
"fs",
|
||||
"as",
|
||||
]
|
||||
]
|
||||
|
||||
if not pa_version_under10p1:
|
||||
import pyarrow as pa
|
||||
|
||||
UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
|
||||
SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
|
||||
ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
|
||||
ALL_INT_PYARROW_DTYPES_STR_REPR = [
|
||||
str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
|
||||
]
|
||||
|
||||
# pa.float16 doesn't seem supported
|
||||
# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
|
||||
FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
|
||||
FLOAT_PYARROW_DTYPES_STR_REPR = [
|
||||
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
|
||||
]
|
||||
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
|
||||
STRING_PYARROW_DTYPES = [pa.string()]
|
||||
BINARY_PYARROW_DTYPES = [pa.binary()]
|
||||
|
||||
TIME_PYARROW_DTYPES = [
|
||||
pa.time32("s"),
|
||||
pa.time32("ms"),
|
||||
pa.time64("us"),
|
||||
pa.time64("ns"),
|
||||
]
|
||||
DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
|
||||
DATETIME_PYARROW_DTYPES = [
|
||||
pa.timestamp(unit=unit, tz=tz)
|
||||
for unit in ["s", "ms", "us", "ns"]
|
||||
for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
|
||||
]
|
||||
TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]
|
||||
|
||||
BOOL_PYARROW_DTYPES = [pa.bool_()]
|
||||
|
||||
# TODO: Add container like pyarrow types:
|
||||
# https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
|
||||
ALL_PYARROW_DTYPES = (
|
||||
ALL_INT_PYARROW_DTYPES
|
||||
+ FLOAT_PYARROW_DTYPES
|
||||
+ DECIMAL_PYARROW_DTYPES
|
||||
+ STRING_PYARROW_DTYPES
|
||||
+ BINARY_PYARROW_DTYPES
|
||||
+ TIME_PYARROW_DTYPES
|
||||
+ DATE_PYARROW_DTYPES
|
||||
+ DATETIME_PYARROW_DTYPES
|
||||
+ TIMEDELTA_PYARROW_DTYPES
|
||||
+ BOOL_PYARROW_DTYPES
|
||||
)
|
||||
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
|
||||
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
|
||||
)
|
||||
else:
|
||||
FLOAT_PYARROW_DTYPES_STR_REPR = []
|
||||
ALL_INT_PYARROW_DTYPES_STR_REPR = []
|
||||
ALL_PYARROW_DTYPES = []
|
||||
ALL_REAL_PYARROW_DTYPES_STR_REPR = []
|
||||
|
||||
ALL_REAL_NULLABLE_DTYPES = (
|
||||
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
|
||||
)
|
||||
|
||||
arithmetic_dunder_methods = [
|
||||
"__add__",
|
||||
"__radd__",
|
||||
"__sub__",
|
||||
"__rsub__",
|
||||
"__mul__",
|
||||
"__rmul__",
|
||||
"__floordiv__",
|
||||
"__rfloordiv__",
|
||||
"__truediv__",
|
||||
"__rtruediv__",
|
||||
"__pow__",
|
||||
"__rpow__",
|
||||
"__mod__",
|
||||
"__rmod__",
|
||||
]
|
||||
|
||||
comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Comparators
|
||||
|
||||
|
||||
def box_expected(expected, box_cls, transpose: bool = True):
|
||||
"""
|
||||
Helper function to wrap the expected output of a test in a given box_class.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expected : np.ndarray, Index, Series
|
||||
box_cls : {Index, Series, DataFrame}
|
||||
|
||||
Returns
|
||||
-------
|
||||
subclass of box_cls
|
||||
"""
|
||||
if box_cls is pd.array:
|
||||
if isinstance(expected, RangeIndex):
|
||||
# pd.array would return an IntegerArray
|
||||
expected = NumpyExtensionArray(np.asarray(expected._values))
|
||||
else:
|
||||
expected = pd.array(expected, copy=False)
|
||||
elif box_cls is Index:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
|
||||
expected = Index(expected)
|
||||
elif box_cls is Series:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
|
||||
expected = Series(expected)
|
||||
elif box_cls is DataFrame:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
|
||||
expected = Series(expected).to_frame()
|
||||
if transpose:
|
||||
# for vector operations, we need a DataFrame to be a single-row,
|
||||
# not a single-column, in order to operate against non-DataFrame
|
||||
# vectors of the same length. But convert to two rows to avoid
|
||||
# single-row special cases in datetime arithmetic
|
||||
expected = expected.T
|
||||
expected = pd.concat([expected] * 2, ignore_index=True)
|
||||
elif box_cls is np.ndarray or box_cls is np.array:
|
||||
expected = np.array(expected)
|
||||
elif box_cls is to_array:
|
||||
expected = to_array(expected)
|
||||
else:
|
||||
raise NotImplementedError(box_cls)
|
||||
return expected
|
||||
|
||||
|
||||
def to_array(obj):
|
||||
"""
|
||||
Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.
|
||||
"""
|
||||
# temporary implementation until we get pd.array in place
|
||||
dtype = getattr(obj, "dtype", None)
|
||||
|
||||
if dtype is None:
|
||||
return np.asarray(obj)
|
||||
|
||||
return extract_array(obj, extract_numpy=True)
|
||||
|
||||
|
||||
class SubclassedSeries(Series):
|
||||
_metadata = ["testattr", "name"]
|
||||
|
||||
@property
|
||||
def _constructor(self):
|
||||
# For testing, those properties return a generic callable, and not
|
||||
# the actual class. In this case that is equivalent, but it is to
|
||||
# ensure we don't rely on the property returning a class
|
||||
# See https://github.com/pandas-dev/pandas/pull/46018 and
|
||||
# https://github.com/pandas-dev/pandas/issues/32638 and linked issues
|
||||
return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _constructor_expanddim(self):
|
||||
return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
|
||||
|
||||
|
||||
class SubclassedDataFrame(DataFrame):
|
||||
_metadata = ["testattr"]
|
||||
|
||||
@property
|
||||
def _constructor(self):
|
||||
return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _constructor_sliced(self):
|
||||
return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
|
||||
|
||||
|
||||
def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
|
||||
"""
|
||||
Convert list of CSV rows to single CSV-formatted string for current OS.
|
||||
|
||||
This method is used for creating expected value of to_csv() method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rows_list : List[str]
|
||||
Each element represents the row of csv.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Expected output of to_csv() in current OS.
|
||||
"""
|
||||
sep = os.linesep
|
||||
return sep.join(rows_list) + sep
|
||||
|
||||
|
||||
def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
|
||||
"""
|
||||
Helper function to mark pytest.raises that have an external error message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expected_exception : Exception
|
||||
Expected error to raise.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
Regular `pytest.raises` function with `match` equal to `None`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
return pytest.raises(expected_exception, match=None)
|
||||
|
||||
|
||||
cython_table = pd.core.common._cython_table.items()
|
||||
|
||||
|
||||
def get_cython_table_params(ndframe, func_names_and_expected):
|
||||
"""
|
||||
Combine frame, functions from com._cython_table
|
||||
keys and expected result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ndframe : DataFrame or Series
|
||||
func_names_and_expected : Sequence of two items
|
||||
The first item is a name of a NDFrame method ('sum', 'prod') etc.
|
||||
The second item is the expected return value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
List of three items (DataFrame, function, expected result)
|
||||
"""
|
||||
results = []
|
||||
for func_name, expected in func_names_and_expected:
|
||||
results.append((ndframe, func_name, expected))
|
||||
results += [
|
||||
(ndframe, func, expected)
|
||||
for func, name in cython_table
|
||||
if name == func_name
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def get_op_from_name(op_name: str) -> Callable:
|
||||
"""
|
||||
The operator function for a given op name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op_name : str
|
||||
The op name, in form of "add" or "__add__".
|
||||
|
||||
Returns
|
||||
-------
|
||||
function
|
||||
A function performing the operation.
|
||||
"""
|
||||
short_opname = op_name.strip("_")
|
||||
try:
|
||||
op = getattr(operator, short_opname)
|
||||
except AttributeError:
|
||||
# Assume it is the reverse operator
|
||||
rop = getattr(operator, short_opname[1:])
|
||||
op = lambda x, y: rop(y, x)
|
||||
|
||||
return op
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Indexing test helpers
|
||||
|
||||
|
||||
def getitem(x):
|
||||
return x
|
||||
|
||||
|
||||
def setitem(x):
|
||||
return x
|
||||
|
||||
|
||||
def loc(x):
|
||||
return x.loc
|
||||
|
||||
|
||||
def iloc(x):
|
||||
return x.iloc
|
||||
|
||||
|
||||
def at(x):
|
||||
return x.at
|
||||
|
||||
|
||||
def iat(x):
|
||||
return x.iat
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_UNITS = ["s", "ms", "us", "ns"]
|
||||
|
||||
|
||||
def get_finest_unit(left: str, right: str):
|
||||
"""
|
||||
Find the higher of two datetime64 units.
|
||||
"""
|
||||
if _UNITS.index(left) >= _UNITS.index(right):
|
||||
return left
|
||||
return right
|
||||
|
||||
|
||||
def shares_memory(left, right) -> bool:
|
||||
"""
|
||||
Pandas-compat for np.shares_memory.
|
||||
"""
|
||||
if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return np.shares_memory(left, right)
|
||||
elif isinstance(left, np.ndarray):
|
||||
# Call with reversed args to get to unpacking logic below.
|
||||
return shares_memory(right, left)
|
||||
|
||||
if isinstance(left, RangeIndex):
|
||||
return False
|
||||
if isinstance(left, MultiIndex):
|
||||
return shares_memory(left._codes, right)
|
||||
if isinstance(left, (Index, Series)):
|
||||
return shares_memory(left._values, right)
|
||||
|
||||
if isinstance(left, NDArrayBackedExtensionArray):
|
||||
return shares_memory(left._ndarray, right)
|
||||
if isinstance(left, pd.core.arrays.SparseArray):
|
||||
return shares_memory(left.sp_values, right)
|
||||
if isinstance(left, pd.core.arrays.IntervalArray):
|
||||
return shares_memory(left._left, right) or shares_memory(left._right, right)
|
||||
|
||||
if (
|
||||
isinstance(left, ExtensionArray)
|
||||
and is_string_dtype(left.dtype)
|
||||
and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
|
||||
):
|
||||
# https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
|
||||
left = cast("ArrowExtensionArray", left)
|
||||
if (
|
||||
isinstance(right, ExtensionArray)
|
||||
and is_string_dtype(right.dtype)
|
||||
and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
|
||||
):
|
||||
right = cast("ArrowExtensionArray", right)
|
||||
left_pa_data = left._pa_array
|
||||
right_pa_data = right._pa_array
|
||||
left_buf1 = left_pa_data.chunk(0).buffers()[1]
|
||||
right_buf1 = right_pa_data.chunk(0).buffers()[1]
|
||||
return left_buf1 == right_buf1
|
||||
|
||||
if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
|
||||
# By convention, we'll say these share memory if they share *either*
|
||||
# the _data or the _mask
|
||||
return np.shares_memory(left._data, right._data) or np.shares_memory(
|
||||
left._mask, right._mask
|
||||
)
|
||||
|
||||
if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
|
||||
arr = left._mgr.arrays[0]
|
||||
return shares_memory(arr, right)
|
||||
|
||||
raise NotImplementedError(type(left), type(right))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ALL_INT_EA_DTYPES",
|
||||
"ALL_INT_NUMPY_DTYPES",
|
||||
"ALL_NUMPY_DTYPES",
|
||||
"ALL_REAL_NUMPY_DTYPES",
|
||||
"assert_almost_equal",
|
||||
"assert_attr_equal",
|
||||
"assert_categorical_equal",
|
||||
"assert_class_equal",
|
||||
"assert_contains_all",
|
||||
"assert_copy",
|
||||
"assert_datetime_array_equal",
|
||||
"assert_dict_equal",
|
||||
"assert_equal",
|
||||
"assert_extension_array_equal",
|
||||
"assert_frame_equal",
|
||||
"assert_index_equal",
|
||||
"assert_indexing_slices_equivalent",
|
||||
"assert_interval_array_equal",
|
||||
"assert_is_sorted",
|
||||
"assert_is_valid_plot_return_object",
|
||||
"assert_metadata_equivalent",
|
||||
"assert_numpy_array_equal",
|
||||
"assert_period_array_equal",
|
||||
"assert_produces_warning",
|
||||
"assert_series_equal",
|
||||
"assert_sp_array_equal",
|
||||
"assert_timedelta_array_equal",
|
||||
"assert_cow_warning",
|
||||
"at",
|
||||
"BOOL_DTYPES",
|
||||
"box_expected",
|
||||
"BYTES_DTYPES",
|
||||
"can_set_locale",
|
||||
"COMPLEX_DTYPES",
|
||||
"convert_rows_list_to_csv_str",
|
||||
"DATETIME64_DTYPES",
|
||||
"decompress_file",
|
||||
"ENDIAN",
|
||||
"ensure_clean",
|
||||
"external_error_raised",
|
||||
"FLOAT_EA_DTYPES",
|
||||
"FLOAT_NUMPY_DTYPES",
|
||||
"get_cython_table_params",
|
||||
"get_dtype",
|
||||
"getitem",
|
||||
"get_locales",
|
||||
"get_finest_unit",
|
||||
"get_obj",
|
||||
"get_op_from_name",
|
||||
"iat",
|
||||
"iloc",
|
||||
"loc",
|
||||
"maybe_produces_warning",
|
||||
"NARROW_NP_DTYPES",
|
||||
"NP_NAT_OBJECTS",
|
||||
"NULL_OBJECTS",
|
||||
"OBJECT_DTYPES",
|
||||
"raise_assert_detail",
|
||||
"raises_chained_assignment_error",
|
||||
"round_trip_localpath",
|
||||
"round_trip_pathlib",
|
||||
"round_trip_pickle",
|
||||
"setitem",
|
||||
"set_locale",
|
||||
"set_timezone",
|
||||
"shares_memory",
|
||||
"SIGNED_INT_EA_DTYPES",
|
||||
"SIGNED_INT_NUMPY_DTYPES",
|
||||
"STRING_DTYPES",
|
||||
"SubclassedDataFrame",
|
||||
"SubclassedSeries",
|
||||
"TIMEDELTA64_DTYPES",
|
||||
"to_array",
|
||||
"UNSIGNED_INT_EA_DTYPES",
|
||||
"UNSIGNED_INT_NUMPY_DTYPES",
|
||||
"use_numexpr",
|
||||
"with_csv_dialect",
|
||||
"write_to_compressed",
|
||||
]
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Hypothesis data generator helpers.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from hypothesis import strategies as st
|
||||
from hypothesis.extra.dateutil import timezones as dateutil_timezones
|
||||
from hypothesis.extra.pytz import timezones as pytz_timezones
|
||||
|
||||
from pandas.compat import is_platform_windows
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pandas.tseries.offsets import (
|
||||
BMonthBegin,
|
||||
BMonthEnd,
|
||||
BQuarterBegin,
|
||||
BQuarterEnd,
|
||||
BYearBegin,
|
||||
BYearEnd,
|
||||
MonthBegin,
|
||||
MonthEnd,
|
||||
QuarterBegin,
|
||||
QuarterEnd,
|
||||
YearBegin,
|
||||
YearEnd,
|
||||
)
|
||||
|
||||
OPTIONAL_INTS = st.lists(st.one_of(st.integers(), st.none()), max_size=10, min_size=3)
|
||||
|
||||
OPTIONAL_FLOATS = st.lists(st.one_of(st.floats(), st.none()), max_size=10, min_size=3)
|
||||
|
||||
OPTIONAL_TEXT = st.lists(st.one_of(st.none(), st.text()), max_size=10, min_size=3)
|
||||
|
||||
OPTIONAL_DICTS = st.lists(
|
||||
st.one_of(st.none(), st.dictionaries(st.text(), st.integers())),
|
||||
max_size=10,
|
||||
min_size=3,
|
||||
)
|
||||
|
||||
OPTIONAL_LISTS = st.lists(
|
||||
st.one_of(st.none(), st.lists(st.text(), max_size=10, min_size=3)),
|
||||
max_size=10,
|
||||
min_size=3,
|
||||
)
|
||||
|
||||
OPTIONAL_ONE_OF_ALL = st.one_of(
|
||||
OPTIONAL_DICTS, OPTIONAL_FLOATS, OPTIONAL_INTS, OPTIONAL_LISTS, OPTIONAL_TEXT
|
||||
)
|
||||
|
||||
if is_platform_windows():
|
||||
DATETIME_NO_TZ = st.datetimes(min_value=datetime(1900, 1, 1))
|
||||
else:
|
||||
DATETIME_NO_TZ = st.datetimes()
|
||||
|
||||
DATETIME_JAN_1_1900_OPTIONAL_TZ = st.datetimes(
|
||||
min_value=pd.Timestamp(
|
||||
1900, 1, 1
|
||||
).to_pydatetime(), # pyright: ignore[reportGeneralTypeIssues]
|
||||
max_value=pd.Timestamp(
|
||||
1900, 1, 1
|
||||
).to_pydatetime(), # pyright: ignore[reportGeneralTypeIssues]
|
||||
timezones=st.one_of(st.none(), dateutil_timezones(), pytz_timezones()),
|
||||
)
|
||||
|
||||
DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ = st.datetimes(
|
||||
min_value=pd.Timestamp.min.to_pydatetime(warn=False),
|
||||
max_value=pd.Timestamp.max.to_pydatetime(warn=False),
|
||||
)
|
||||
|
||||
INT_NEG_999_TO_POS_999 = st.integers(-999, 999)
|
||||
|
||||
# The strategy for each type is registered in conftest.py, as they don't carry
|
||||
# enough runtime information (e.g. type hints) to infer how to build them.
|
||||
YQM_OFFSET = st.one_of(
|
||||
*map(
|
||||
st.from_type,
|
||||
[
|
||||
MonthBegin,
|
||||
MonthEnd,
|
||||
BMonthBegin,
|
||||
BMonthEnd,
|
||||
QuarterBegin,
|
||||
QuarterEnd,
|
||||
BQuarterBegin,
|
||||
BQuarterEnd,
|
||||
YearBegin,
|
||||
YearEnd,
|
||||
BYearBegin,
|
||||
BYearEnd,
|
||||
],
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import io
|
||||
import pathlib
|
||||
import tarfile
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from pandas.compat import (
|
||||
get_bz2_file,
|
||||
get_lzma_file,
|
||||
)
|
||||
from pandas.compat._optional import import_optional_dependency
|
||||
|
||||
import pandas as pd
|
||||
from pandas._testing.contexts import ensure_clean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas._typing import (
|
||||
FilePath,
|
||||
ReadPickleBuffer,
|
||||
)
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File-IO
|
||||
|
||||
|
||||
def round_trip_pickle(
|
||||
obj: Any, path: FilePath | ReadPickleBuffer | None = None
|
||||
) -> DataFrame | Series:
|
||||
"""
|
||||
Pickle an object and then read it again.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : any object
|
||||
The object to pickle and then re-read.
|
||||
path : str, path object or file-like object, default None
|
||||
The path where the pickled object is written and then read.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pandas object
|
||||
The original object that was pickled and then re-read.
|
||||
"""
|
||||
_path = path
|
||||
if _path is None:
|
||||
_path = f"__{uuid.uuid4()}__.pickle"
|
||||
with ensure_clean(_path) as temp_path:
|
||||
pd.to_pickle(obj, temp_path)
|
||||
return pd.read_pickle(temp_path)
|
||||
|
||||
|
||||
def round_trip_pathlib(writer, reader, path: str | None = None):
|
||||
"""
|
||||
Write an object to file specified by a pathlib.Path and read it back
|
||||
|
||||
Parameters
|
||||
----------
|
||||
writer : callable bound to pandas object
|
||||
IO writing function (e.g. DataFrame.to_csv )
|
||||
reader : callable
|
||||
IO reading function (e.g. pd.read_csv )
|
||||
path : str, default None
|
||||
The path where the object is written and then read.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pandas object
|
||||
The original object that was serialized and then re-read.
|
||||
"""
|
||||
Path = pathlib.Path
|
||||
if path is None:
|
||||
path = "___pathlib___"
|
||||
with ensure_clean(path) as path:
|
||||
writer(Path(path)) # type: ignore[arg-type]
|
||||
obj = reader(Path(path)) # type: ignore[arg-type]
|
||||
return obj
|
||||
|
||||
|
||||
def round_trip_localpath(writer, reader, path: str | None = None):
|
||||
"""
|
||||
Write an object to file specified by a py.path LocalPath and read it back.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
writer : callable bound to pandas object
|
||||
IO writing function (e.g. DataFrame.to_csv )
|
||||
reader : callable
|
||||
IO reading function (e.g. pd.read_csv )
|
||||
path : str, default None
|
||||
The path where the object is written and then read.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pandas object
|
||||
The original object that was serialized and then re-read.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
LocalPath = pytest.importorskip("py.path").local
|
||||
if path is None:
|
||||
path = "___localpath___"
|
||||
with ensure_clean(path) as path:
|
||||
writer(LocalPath(path))
|
||||
obj = reader(LocalPath(path))
|
||||
return obj
|
||||
|
||||
|
||||
def write_to_compressed(compression, path, data, dest: str = "test") -> None:
|
||||
"""
|
||||
Write data to a compressed file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'}
|
||||
The compression type to use.
|
||||
path : str
|
||||
The file path to write the data.
|
||||
data : str
|
||||
The data to write.
|
||||
dest : str, default "test"
|
||||
The destination file (for ZIP only)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError : An invalid compression value was passed in.
|
||||
"""
|
||||
args: tuple[Any, ...] = (data,)
|
||||
mode = "wb"
|
||||
method = "write"
|
||||
compress_method: Callable
|
||||
|
||||
if compression == "zip":
|
||||
compress_method = zipfile.ZipFile
|
||||
mode = "w"
|
||||
args = (dest, data)
|
||||
method = "writestr"
|
||||
elif compression == "tar":
|
||||
compress_method = tarfile.TarFile
|
||||
mode = "w"
|
||||
file = tarfile.TarInfo(name=dest)
|
||||
bytes = io.BytesIO(data)
|
||||
file.size = len(data)
|
||||
args = (file, bytes)
|
||||
method = "addfile"
|
||||
elif compression == "gzip":
|
||||
compress_method = gzip.GzipFile
|
||||
elif compression == "bz2":
|
||||
compress_method = get_bz2_file()
|
||||
elif compression == "zstd":
|
||||
compress_method = import_optional_dependency("zstandard").open
|
||||
elif compression == "xz":
|
||||
compress_method = get_lzma_file()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized compression type: {compression}")
|
||||
|
||||
with compress_method(path, mode=mode) as f:
|
||||
getattr(f, method)(*args)
|
||||
@@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import (
|
||||
contextmanager,
|
||||
nullcontext,
|
||||
)
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
import warnings
|
||||
|
||||
from pandas.compat import PY311
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
Generator,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def assert_produces_warning(
|
||||
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
|
||||
filter_level: Literal[
|
||||
"error", "ignore", "always", "default", "module", "once"
|
||||
] = "always",
|
||||
check_stacklevel: bool = True,
|
||||
raise_on_extra_warnings: bool = True,
|
||||
match: str | None = None,
|
||||
) -> Generator[list[warnings.WarningMessage], None, None]:
|
||||
"""
|
||||
Context manager for running code expected to either raise a specific warning,
|
||||
multiple specific warnings, or not raise any warnings. Verifies that the code
|
||||
raises the expected warning(s), and that it does not raise any other unexpected
|
||||
warnings. It is basically a wrapper around ``warnings.catch_warnings``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
|
||||
The type of Exception raised. ``exception.Warning`` is the base
|
||||
class for all warnings. To raise multiple types of exceptions,
|
||||
pass them as a tuple. To check that no warning is returned,
|
||||
specify ``False`` or ``None``.
|
||||
filter_level : str or None, default "always"
|
||||
Specifies whether warnings are ignored, displayed, or turned
|
||||
into errors.
|
||||
Valid values are:
|
||||
|
||||
* "error" - turns matching warnings into exceptions
|
||||
* "ignore" - discard the warning
|
||||
* "always" - always emit a warning
|
||||
* "default" - print the warning the first time it is generated
|
||||
from each location
|
||||
* "module" - print the warning the first time it is generated
|
||||
from each module
|
||||
* "once" - print the warning the first time it is generated
|
||||
|
||||
check_stacklevel : bool, default True
|
||||
If True, displays the line that called the function containing
|
||||
the warning to show were the function is called. Otherwise, the
|
||||
line that implements the function is displayed.
|
||||
raise_on_extra_warnings : bool, default True
|
||||
Whether extra warnings not of the type `expected_warning` should
|
||||
cause the test to fail.
|
||||
match : str, optional
|
||||
Match warning message.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import warnings
|
||||
>>> with assert_produces_warning():
|
||||
... warnings.warn(UserWarning())
|
||||
...
|
||||
>>> with assert_produces_warning(False):
|
||||
... warnings.warn(RuntimeWarning())
|
||||
...
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
|
||||
>>> with assert_produces_warning(UserWarning):
|
||||
... warnings.warn(RuntimeWarning())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Did not see expected warning of class 'UserWarning'.
|
||||
|
||||
..warn:: This is *not* thread-safe.
|
||||
"""
|
||||
__tracebackhide__ = True
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter(filter_level)
|
||||
try:
|
||||
yield w
|
||||
finally:
|
||||
if expected_warning:
|
||||
expected_warning = cast(type[Warning], expected_warning)
|
||||
_assert_caught_expected_warning(
|
||||
caught_warnings=w,
|
||||
expected_warning=expected_warning,
|
||||
match=match,
|
||||
check_stacklevel=check_stacklevel,
|
||||
)
|
||||
if raise_on_extra_warnings:
|
||||
_assert_caught_no_extra_warnings(
|
||||
caught_warnings=w,
|
||||
expected_warning=expected_warning,
|
||||
)
|
||||
|
||||
|
||||
def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
|
||||
"""
|
||||
Return a context manager that possibly checks a warning based on the condition
|
||||
"""
|
||||
if condition:
|
||||
return assert_produces_warning(warning, **kwargs)
|
||||
else:
|
||||
return nullcontext()
|
||||
|
||||
|
||||
def _assert_caught_expected_warning(
|
||||
*,
|
||||
caught_warnings: Sequence[warnings.WarningMessage],
|
||||
expected_warning: type[Warning],
|
||||
match: str | None,
|
||||
check_stacklevel: bool,
|
||||
) -> None:
|
||||
"""Assert that there was the expected warning among the caught warnings."""
|
||||
saw_warning = False
|
||||
matched_message = False
|
||||
unmatched_messages = []
|
||||
|
||||
for actual_warning in caught_warnings:
|
||||
if issubclass(actual_warning.category, expected_warning):
|
||||
saw_warning = True
|
||||
|
||||
if check_stacklevel:
|
||||
_assert_raised_with_correct_stacklevel(actual_warning)
|
||||
|
||||
if match is not None:
|
||||
if re.search(match, str(actual_warning.message)):
|
||||
matched_message = True
|
||||
else:
|
||||
unmatched_messages.append(actual_warning.message)
|
||||
|
||||
if not saw_warning:
|
||||
raise AssertionError(
|
||||
f"Did not see expected warning of class "
|
||||
f"{repr(expected_warning.__name__)}"
|
||||
)
|
||||
|
||||
if match and not matched_message:
|
||||
raise AssertionError(
|
||||
f"Did not see warning {repr(expected_warning.__name__)} "
|
||||
f"matching '{match}'. The emitted warning messages are "
|
||||
f"{unmatched_messages}"
|
||||
)
|
||||
|
||||
|
||||
def _assert_caught_no_extra_warnings(
|
||||
*,
|
||||
caught_warnings: Sequence[warnings.WarningMessage],
|
||||
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
|
||||
) -> None:
|
||||
"""Assert that no extra warnings apart from the expected ones are caught."""
|
||||
extra_warnings = []
|
||||
|
||||
for actual_warning in caught_warnings:
|
||||
if _is_unexpected_warning(actual_warning, expected_warning):
|
||||
# GH#38630 pytest.filterwarnings does not suppress these.
|
||||
if actual_warning.category == ResourceWarning:
|
||||
# GH 44732: Don't make the CI flaky by filtering SSL-related
|
||||
# ResourceWarning from dependencies
|
||||
if "unclosed <ssl.SSLSocket" in str(actual_warning.message):
|
||||
continue
|
||||
# GH 44844: Matplotlib leaves font files open during the entire process
|
||||
# upon import. Don't make CI flaky if ResourceWarning raised
|
||||
# due to these open files.
|
||||
if any("matplotlib" in mod for mod in sys.modules):
|
||||
continue
|
||||
if PY311 and actual_warning.category == EncodingWarning:
|
||||
# EncodingWarnings are checked in the CI
|
||||
# pyproject.toml errors on EncodingWarnings in pandas
|
||||
# Ignore EncodingWarnings from other libraries
|
||||
continue
|
||||
extra_warnings.append(
|
||||
(
|
||||
actual_warning.category.__name__,
|
||||
actual_warning.message,
|
||||
actual_warning.filename,
|
||||
actual_warning.lineno,
|
||||
)
|
||||
)
|
||||
|
||||
if extra_warnings:
|
||||
raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}")
|
||||
|
||||
|
||||
def _is_unexpected_warning(
|
||||
actual_warning: warnings.WarningMessage,
|
||||
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
|
||||
) -> bool:
|
||||
"""Check if the actual warning issued is unexpected."""
|
||||
if actual_warning and not expected_warning:
|
||||
return True
|
||||
expected_warning = cast(type[Warning], expected_warning)
|
||||
return bool(not issubclass(actual_warning.category, expected_warning))
|
||||
|
||||
|
||||
def _assert_raised_with_correct_stacklevel(
|
||||
actual_warning: warnings.WarningMessage,
|
||||
) -> None:
|
||||
# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
|
||||
frame = inspect.currentframe()
|
||||
for _ in range(4):
|
||||
frame = frame.f_back # type: ignore[union-attr]
|
||||
try:
|
||||
caller_filename = inspect.getfile(frame) # type: ignore[arg-type]
|
||||
finally:
|
||||
# See note in
|
||||
# https://docs.python.org/3/library/inspect.html#inspect.Traceback
|
||||
del frame
|
||||
msg = (
|
||||
"Warning not set with correct stacklevel. "
|
||||
f"File where warning is raised: {actual_warning.filename} != "
|
||||
f"{caller_filename}. Warning message: {actual_warning.message}"
|
||||
)
|
||||
assert actual_warning.filename == caller_filename, msg
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Helpers for sharing tests between DataFrame/Series
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas._typing import DtypeObj
|
||||
|
||||
|
||||
def get_dtype(obj) -> DtypeObj:
|
||||
if isinstance(obj, DataFrame):
|
||||
# Note: we are assuming only one column
|
||||
return obj.dtypes.iat[0]
|
||||
else:
|
||||
return obj.dtype
|
||||
|
||||
|
||||
def get_obj(df: DataFrame, klass):
|
||||
"""
|
||||
For sharing tests using frame_or_series, either return the DataFrame
|
||||
unchanged or return it's first column as a Series.
|
||||
"""
|
||||
if klass is DataFrame:
|
||||
return df
|
||||
return df._ixs(0, axis=1)
|
||||
@@ -0,0 +1,257 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from pandas._config import using_copy_on_write
|
||||
|
||||
from pandas.compat import PYPY
|
||||
from pandas.errors import ChainedAssignmentError
|
||||
|
||||
from pandas import set_option
|
||||
|
||||
from pandas.io.common import get_handle
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from pandas._typing import (
|
||||
BaseBuffer,
|
||||
CompressionOptions,
|
||||
FilePath,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def decompress_file(
|
||||
path: FilePath | BaseBuffer, compression: CompressionOptions
|
||||
) -> Generator[IO[bytes], None, None]:
|
||||
"""
|
||||
Open a compressed file and return a file object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
The path where the file is read from.
|
||||
|
||||
compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None}
|
||||
Name of the decompression to use
|
||||
|
||||
Returns
|
||||
-------
|
||||
file object
|
||||
"""
|
||||
with get_handle(path, "rb", compression=compression, is_text=False) as handle:
|
||||
yield handle.handle
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_timezone(tz: str) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager for temporarily setting a timezone.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tz : str
|
||||
A string representing a valid timezone.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from datetime import datetime
|
||||
>>> from dateutil.tz import tzlocal
|
||||
>>> tzlocal().tzname(datetime(2021, 1, 1)) # doctest: +SKIP
|
||||
'IST'
|
||||
|
||||
>>> with set_timezone('US/Eastern'):
|
||||
... tzlocal().tzname(datetime(2021, 1, 1))
|
||||
...
|
||||
'EST'
|
||||
"""
|
||||
import time
|
||||
|
||||
def setTZ(tz) -> None:
|
||||
if tz is None:
|
||||
try:
|
||||
del os.environ["TZ"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
os.environ["TZ"] = tz
|
||||
time.tzset()
|
||||
|
||||
orig_tz = os.environ.get("TZ")
|
||||
setTZ(tz)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
setTZ(orig_tz)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ensure_clean(
|
||||
filename=None, return_filelike: bool = False, **kwargs: Any
|
||||
) -> Generator[Any, None, None]:
|
||||
"""
|
||||
Gets a temporary path and agrees to remove on close.
|
||||
|
||||
This implementation does not use tempfile.mkstemp to avoid having a file handle.
|
||||
If the code using the returned path wants to delete the file itself, windows
|
||||
requires that no program has a file handle to it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str (optional)
|
||||
suffix of the created file.
|
||||
return_filelike : bool (default False)
|
||||
if True, returns a file-like which is *always* cleaned. Necessary for
|
||||
savefig and other functions which want to append extensions.
|
||||
**kwargs
|
||||
Additional keywords are passed to open().
|
||||
|
||||
"""
|
||||
folder = Path(tempfile.gettempdir())
|
||||
|
||||
if filename is None:
|
||||
filename = ""
|
||||
filename = str(uuid.uuid4()) + filename
|
||||
path = folder / filename
|
||||
|
||||
path.touch()
|
||||
|
||||
handle_or_str: str | IO = str(path)
|
||||
encoding = kwargs.pop("encoding", None)
|
||||
if return_filelike:
|
||||
kwargs.setdefault("mode", "w+b")
|
||||
if encoding is None and "b" not in kwargs["mode"]:
|
||||
encoding = "utf-8"
|
||||
handle_or_str = open(path, encoding=encoding, **kwargs)
|
||||
|
||||
try:
|
||||
yield handle_or_str
|
||||
finally:
|
||||
if not isinstance(handle_or_str, str):
|
||||
handle_or_str.close()
|
||||
if path.is_file():
|
||||
path.unlink()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_csv_dialect(name: str, **kwargs) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager to temporarily register a CSV dialect for parsing CSV.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the dialect.
|
||||
kwargs : mapping
|
||||
The parameters for the dialect.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError : the name of the dialect conflicts with a builtin one.
|
||||
|
||||
See Also
|
||||
--------
|
||||
csv : Python's CSV library.
|
||||
"""
|
||||
import csv
|
||||
|
||||
_BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}
|
||||
|
||||
if name in _BUILTIN_DIALECTS:
|
||||
raise ValueError("Cannot override builtin dialect.")
|
||||
|
||||
csv.register_dialect(name, **kwargs)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
csv.unregister_dialect(name)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_numexpr(use, min_elements=None) -> Generator[None, None, None]:
|
||||
from pandas.core.computation import expressions as expr
|
||||
|
||||
if min_elements is None:
|
||||
min_elements = expr._MIN_ELEMENTS
|
||||
|
||||
olduse = expr.USE_NUMEXPR
|
||||
oldmin = expr._MIN_ELEMENTS
|
||||
set_option("compute.use_numexpr", use)
|
||||
expr._MIN_ELEMENTS = min_elements
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
expr._MIN_ELEMENTS = oldmin
|
||||
set_option("compute.use_numexpr", olduse)
|
||||
|
||||
|
||||
def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=()):
|
||||
from pandas._testing import assert_produces_warning
|
||||
|
||||
if not warn:
|
||||
from contextlib import nullcontext
|
||||
|
||||
return nullcontext()
|
||||
|
||||
if PYPY and not extra_warnings:
|
||||
from contextlib import nullcontext
|
||||
|
||||
return nullcontext()
|
||||
elif PYPY and extra_warnings:
|
||||
return assert_produces_warning(
|
||||
extra_warnings,
|
||||
match="|".join(extra_match),
|
||||
)
|
||||
else:
|
||||
if using_copy_on_write():
|
||||
warning = ChainedAssignmentError
|
||||
match = (
|
||||
"A value is trying to be set on a copy of a DataFrame or Series "
|
||||
"through chained assignment"
|
||||
)
|
||||
else:
|
||||
warning = FutureWarning # type: ignore[assignment]
|
||||
# TODO update match
|
||||
match = "ChainedAssignmentError"
|
||||
if extra_warnings:
|
||||
warning = (warning, *extra_warnings) # type: ignore[assignment]
|
||||
return assert_produces_warning(
|
||||
warning,
|
||||
match="|".join((match, *extra_match)),
|
||||
)
|
||||
|
||||
|
||||
def assert_cow_warning(warn=True, match=None, **kwargs):
|
||||
"""
|
||||
Assert that a warning is raised in the CoW warning mode.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
warn : bool, default True
|
||||
By default, check that a warning is raised. Can be turned off by passing False.
|
||||
match : str
|
||||
The warning message to match against, if different from the default.
|
||||
kwargs
|
||||
Passed through to assert_produces_warning
|
||||
"""
|
||||
from pandas._testing import assert_produces_warning
|
||||
|
||||
if not warn:
|
||||
from contextlib import nullcontext
|
||||
|
||||
return nullcontext()
|
||||
|
||||
if not match:
|
||||
match = "Setting a value on a view"
|
||||
|
||||
return assert_produces_warning(FutureWarning, match=match, **kwargs)
|
||||
Reference in New Issue
Block a user