diff --git a/amlb/datautils.py b/amlb/datautils.py index 7378b7b33..94aab0981 100644 --- a/amlb/datautils.py +++ b/amlb/datautils.py @@ -59,18 +59,18 @@ def read_csv(path, nrows=None, header=True, index=False, as_data_frame=True, dty return df if as_data_frame else df.values -def write_csv( +def write_csv( # type: ignore[no-untyped-def] data: pd.DataFrame | dict | list | np.ndarray, path, header: bool = True, columns: Iterable[str] | None = None, index: bool = False, append: bool = False -) -> None: # type: ignore # path required pandas internal types - if is_data_frame(data): +) -> None: + if isinstance(data, pd.DataFrame): data_frame = data else: - data_frame = to_data_frame(data, columns=columns) + data_frame = to_data_frame(data, column_names=columns) header = header and columns is not None touch(path) data_frame.to_csv( @@ -139,17 +139,17 @@ def is_data_frame(df: object) -> bool: return isinstance(df, pd.DataFrame) -def to_data_frame(obj: object, columns: Iterable[str]| None=None): +def to_data_frame(obj: object, column_names: Iterable[str]| None=None) -> pd.DataFrame: if obj is None: return pd.DataFrame() - elif isinstance(obj, dict): + columns = list(column_names) if column_names else None + if isinstance(obj, dict): orient = cast(Literal['columns', 'index'], 'columns' if columns is None else 'index') return pd.DataFrame.from_dict(obj, columns=columns, orient=orient) - elif isinstance(obj, (list, np.ndarray)): + if isinstance(obj, (list, np.ndarray)): return pd.DataFrame.from_records(obj, columns=columns) - else: - raise ValueError("Object should be a dictionary {col1:values, col2:values, ...} " - "or an array of dictionary-like objects [{col1:val, col2:val}, {col1:val, col2:val}, ...].") + raise ValueError("Object should be a dictionary {col1:values, col2:values, ...} " + "or an array of dictionary-like objects [{col1:val, col2:val}, {col1:val, col2:val}, ...].") class Encoder(TransformerMixin):