Skip to content

Commit

Permalink
plot_timeseries: wide-format refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Nov 1, 2024
1 parent 430bd3e commit c55a59e
Showing 1 changed file with 64 additions and 71 deletions.
135 changes: 64 additions & 71 deletions src/pytimetk/plot/plot_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def plot_timeseries(
.plot_timeseries(
date_column = 'date',
value_column = ['sales', 'expenses', 'profit'],
# color_column = ['sales', 'expenses', 'profit'],
color_column = ['sales', 'expenses', 'profit'],
facet_ncol = 2,
smooth = True,
y_intercept = 0,
Expand All @@ -447,7 +447,7 @@ def plot_timeseries(
# Common checks
check_dataframe_or_groupby(data)
check_date_column(data, date_column)

if isinstance(value_column, list):
for col in value_column:
check_value_column(data, col)
Expand All @@ -469,80 +469,72 @@ def plot_timeseries(
y_intercept_color = name_to_hex(y_intercept_color)
x_intercept_color = name_to_hex(x_intercept_color)

# Handle DataFrames
if isinstance(data, pd.DataFrame):
group_names = None
data = data.copy()
# Initialize group_names
group_names = None

# Reshape data if value_column is a list
if isinstance(value_column, list):
data = pd.melt(
data,
id_vars=[date_column],
value_vars=value_column,
var_name='__value_column',
value_name='__value'
)
group_names = '__value_column'
value_column = '__value' # Update value_column to the new column name
if color_column is value_column:
color_column = '__value_column' # Use variable names as the color column
else:
# Ensure value_column is in data
if value_column not in data.columns:
raise ValueError(f"value_column '{value_column}' not found in DataFrame.")

# Handle smoother
if smooth:
data['__smooth'] = np.nan
if color_column is None:
sorted_data = data.sort_values(by=date_column)
x = np.arange(len(sorted_data))
y = sorted_data[value_column].to_numpy()
data['__smooth'] = lowess(y, x, frac=smooth_frac, return_sorted=False)
else:
for name, group in data.groupby(color_column):
sorted_group = group.sort_values(by=date_column)
x = np.arange(len(sorted_group))
y = sorted_group[value_column].to_numpy()
smoothed = lowess(y, x, frac=smooth_frac)
data.loc[sorted_group.index, '__smooth'] = smoothed[:, 1]

# Handle GroupBy objects
elif isinstance(data, pd.core.groupby.generic.DataFrameGroupBy):
# Handle DataFrames and GroupBy objects
if isinstance(data, pd.core.groupby.generic.DataFrameGroupBy):
group_names = data.grouper.names
data = data.obj.copy()
elif isinstance(data, pd.DataFrame):
data = data.copy()
else:
raise ValueError("data must be a DataFrame or GroupBy object")

# Reshape data if value_column is a list
if isinstance(value_column, list):
data = pd.melt(
data,
id_vars=group_names + [date_column],
value_vars=value_column,
var_name='__value_column',
value_name='__value'
)
value_column = '__value' # Update value_column to the new column name
if color_column is None:
color_column = '__value_column' # Use variable names as the color column
# Reshape data if value_column is a list (wide format)
if isinstance(value_column, list):
if color_column == value_column:
color_column = '__value_column' # Use variable names as the color column

if group_names is None:
id_vars = [date_column]
else:
id_vars = group_names + [date_column]

data = pd.melt(
data,
id_vars=id_vars,
value_vars=value_column,
var_name='__value_column',
value_name='__value'
)

value_column = '__value' # Update value_column to the new column name

if group_names is None:
group_names = ['__value_column']
else:
group_names = group_names + ['__value_column']

legend_show = False
else:
# Ensure value_column is in data
if value_column not in data.columns:
raise ValueError(f"value_column '{value_column}' not found in DataFrame.")

# Handle smoother
if smooth:
data['__smooth'] = np.nan
if color_column is None:
for name, group in data.groupby(group_names):
sorted_group = group.sort_values(by=date_column)
x = np.arange(len(sorted_group))
y = sorted_group[value_column].to_numpy()
smoothed = lowess(y, x, frac=smooth_frac)
data.loc[sorted_group.index, '__smooth'] = smoothed[:, 1]
else:
for name, group in data.groupby(group_names + [color_column]):
sorted_group = group.sort_values(by=date_column)
x = np.arange(len(sorted_group))
y = sorted_group[value_column].to_numpy()
smoothed = lowess(y, x, frac=smooth_frac)
data.loc[sorted_group.index, '__smooth'] = smoothed[:, 1]
# Handle smoother
if smooth:
data['__smooth'] = np.nan
# Determine grouping columns
grouping_columns = []
if group_names is not None:
grouping_columns.extend(group_names)
if color_column is not None and color_column not in grouping_columns:
grouping_columns.append(color_column)
if grouping_columns:
# Apply smoother per group
for _, group in data.groupby(grouping_columns):
sorted_group = group.sort_values(by=date_column)
x = np.arange(len(sorted_group))
y = sorted_group[value_column].to_numpy()
smoothed = lowess(y, x, frac=smooth_frac)
data.loc[sorted_group.index, '__smooth'] = smoothed[:, 1]
else:
# Apply smoother to the whole data
sorted_data = data.sort_values(by=date_column)
x = np.arange(len(sorted_data))
y = sorted_data[value_column].to_numpy()
data['__smooth'] = lowess(y, x, frac=smooth_frac, return_sorted=False)

# Handle color palette
if color_palette is None:
Expand All @@ -558,6 +550,7 @@ def plot_timeseries(
else:
raise ValueError("Invalid `color_palette` parameter. It must be a dictionary, list, or string.")


# Engine
if engine in ['plotnine', 'matplotlib']:
fig = _plot_timeseries_plotnine(
Expand Down

0 comments on commit c55a59e

Please sign in to comment.