Skip to content

Commit

Permalink
Merge pull request #176 from AI-SDC/histograms
Browse files Browse the repository at this point in the history
histogram
  • Loading branch information
jim-smith authored Oct 18, 2023
2 parents 91b8688 + 94944d7 commit 7cf4884
Show file tree
Hide file tree
Showing 4 changed files with 883 additions and 194 deletions.
189 changes: 189 additions & 0 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,195 @@ def survival_plot( # pylint: disable=too-many-arguments,too-many-locals
)
return plot

def hist( # pylint: disable=too-many-arguments,too-many-locals
self,
data,
column,
by=None,
grid=True,
xlabelsize=None,
xrot=None,
ylabelsize=None,
yrot=None,
ax=None,
sharex=False,
sharey=False,
figsize=None,
layout=None,
bins=10,
backend=None,
legend=False,
filename="histogram.png",
**kwargs,
):
"""Creates a histogram from a single column.
The dataset and the column's name should be passed to the function as parameters.
If more than one column is used the histogram will not be calculated.
To save the histogram plot to a file, the user can specify a filename otherwise
'histogram.png' will be used as the filename. A number will be appended automatically
to the filename to avoid overwriting the files.
Returns
-------
matplotlib.Axes
Parameters:
-----------
data : DataFrame
The pandas object holding the data.
column : str
The column that will be used to plot the histogram.
by : object, optional
If passed, then used to form histograms for separate groups.
grid : bool, default True
Whether to show axis grid lines.
xlabelsize : int, default None
If specified changes the x-axis label size.
xrot : float, default None
Rotation of x axis labels. For example, a value of 90 displays
the x labels rotated 90 degrees clockwise.
ylabelsize : int, default None
If specified changes the y-axis label size.
yrot : float, default None
Rotation of y axis labels. For example, a value of 90 displays
the y labels rotated 90 degrees clockwise.
ax : Matplotlib axes object, default None
The axes to plot the histogram on.
sharex : bool, default True if ax is None else False
In case subplots=True, share x axis and set some x axis labels to invisible;
defaults to True if ax is None otherwise False if an ax is passed in.
Note that passing in both an ax and sharex=True will alter all x axis
labels for all subplots in a figure.
sharey : bool, default False
In case subplots=True, share y axis and set some y axis labels to invisible.
figsize : tuple, optional
The size in inches of the figure to create.
Uses the value in matplotlib.rcParams by default.
layout : tuple, optional
Tuple of (rows, columns) for the layout of the histograms.
bins : int or sequence, default 10
Number of histogram bins to be used. If an integer is given, bins + 1 bin edges are
calculated and returned. If bins is a sequence, gives bin edges,
including left edge of first bin and right edge of last bin.
backend : str, default None
Backend to use instead of the backend specified in the option plotting.backend.
For instance, ‘matplotlib’. Alternatively, to specify the plotting.backend for the
whole session, set pd.options.plotting.backend.
legend : bool, default False
Whether to show the legend.
filename:
The name of the file where the plot will be saved.
"""
logger.debug("hist()")
command: str = utils.get_command("hist()", stack())

if isinstance(data, list): # pragma: no cover
logger.info(
"Calculating histogram for more than one columns is "
"not currently supported. Please do each column separately."
)
return

freq, _ = np.histogram( # pylint: disable=too-many-function-args
data[column], bins, range=(data[column].min(), data[column].max())
)

# threshold check
threshold_mask = freq < THRESHOLD

# plot the histogram
if np.any(threshold_mask): # the column is disclosive
status = "fail"
if self.suppress:
logger.warning(
"Histogram will not be shown as the %s column is disclosive.",
column,
)
else: # pragma: no cover
data.hist(
column=column,
by=by,
grid=grid,
xlabelsize=xlabelsize,
xrot=xrot,
ylabelsize=ylabelsize,
yrot=yrot,
ax=ax,
sharex=sharex,
sharey=sharey,
figsize=figsize,
layout=layout,
bins=bins,
backend=backend,
legend=legend,
**kwargs,
)
else:
status = "review"
data.hist(
column=column,
by=by,
grid=grid,
xlabelsize=xlabelsize,
xrot=xrot,
ylabelsize=ylabelsize,
yrot=yrot,
ax=ax,
sharex=sharex,
sharey=sharey,
figsize=figsize,
layout=layout,
bins=bins,
backend=backend,
legend=legend,
**kwargs,
)
logger.info("status: %s", status)

# create the summary
min_value = data[column].min()
max_value = data[column].max()
summary = (
f"Please check the minimum and the maximum values. "
f"The minimum value of the {column} column is: {min_value}. "
f"The maximum value of the {column} column is: {max_value}"
)

# create the acro_artifacts directory to save the plot in it
try:
os.makedirs("acro_artifacts")
logger.debug("Directory acro_artifacts created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
if not extension: # pragma: no cover
logger.info("Please provide a valid file extension")
return
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
):
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)

# record output
self.results.add(
status=status,
output_type="histogram",
properties={"method": "histogram"},
sdc={},
command=command,
summary=summary,
outcome=pd.DataFrame(),
output=[os.path.normpath(unique_filename)],
)


def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals
index,
Expand Down
25 changes: 7 additions & 18 deletions acro/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,11 @@ def serialize_output(self, path: str = "outputs") -> list[str]:
if os.path.exists(filename):
shutil.copy(filename, path)
output.append(Path(filename).name)
if self.output_type == "survival plot":
if self.output_type in ["survival plot", "histogram"]:
for filename in self.output:
output.append(Path(filename).name)
if os.path.exists(filename):
output.append(Path(filename).name)
shutil.copy(filename, path)
return output

def __str__(self) -> str:
Expand Down Expand Up @@ -446,9 +448,10 @@ def finalise(self, path: str, ext: str) -> None:
self.finalise_excel(path)
else:
raise ValueError("Invalid file extension. Options: {json, xlsx}")
if os.path.exists("acro_artifacts"):
add_acro_artifacts(path)
self.write_checksums(path)
# check if the directory acro_artifacts exists and delete it
if os.path.exists("acro_artifacts"):
shutil.rmtree("acro_artifacts")
logger.info("outputs written to: %s", path)

def finalise_json(self, path: str) -> None:
Expand Down Expand Up @@ -565,20 +568,6 @@ def write_checksums(self, path: str) -> None:
logger.debug("There is no file to do the checksums") # pragma: no cover


def add_acro_artifacts(path: str) -> None:
"""Copy any file from the acro_artifacts directory to the output
directory then delete the directory.
Parameters
----------
path : str
Name of the folder that files are to be written.
"""
for filename in os.listdir("acro_artifacts"):
shutil.copy(f"acro_artifacts/{filename}", path)
shutil.rmtree("acro_artifacts")


def load_records(path: str) -> Records:
"""Loads outputs from a JSON file.
Expand Down
Loading

0 comments on commit 7cf4884

Please sign in to comment.