Skip to content

Commit

Permalink
Merge pull request #4 from nccgroup/pyplot-to-agg
Browse files Browse the repository at this point in the history
Replace pyplot interface with the Figure interface and explicitly set matplotlib backend.
  • Loading branch information
neonbunny authored Jul 8, 2024
2 parents 786fe54 + d653b5f commit f92228e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 61 deletions.
78 changes: 32 additions & 46 deletions event_tracker/views_credentials.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import csv
import io
import itertools
Expand All @@ -22,8 +21,8 @@
from django.views.generic import FormView, ListView, CreateView, UpdateView, DeleteView, TemplateView
from django_datatables_view.base_datatable_view import BaseDatatableView
from djangoplugins.models import ENABLED
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from neo4j.exceptions import ClientError

Expand Down Expand Up @@ -290,13 +289,12 @@ def get_filtered_creds(request):
def password_complexity_piechart(request, task_id):
credential_per_account, _, _, system, enabled = CredentialStatsView.get_filtered_creds(request)

with plot_password_complexity_piechart(credential_per_account, enabled, system) as fig:
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')
fig = plot_password_complexity_piechart(credential_per_account, enabled, system)
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')

return response
return response

@contextlib.contextmanager
def plot_password_complexity_piechart(credential_per_account, enabled, system):
credential_per_account.update(complexity=Case(
When(secret="", then=Value("blank")),
Expand Down Expand Up @@ -340,7 +338,7 @@ def plot_password_complexity_piechart(credential_per_account, enabled, system):
upperalphaspecialnum=Count("pk", filter=Q(complexity="upperalphaspecialnum")),
mixedalphaspecialnum=Count("pk", filter=Q(complexity="mixedalphaspecialnum")),
))
fig = plt.figure(figsize=(10, 8))
fig = Figure(figsize=(10, 8))
ax = fig.subplots(2, height_ratios=[3, 1])
piesegments = [counts['blank'], counts['numeric'], counts['special'],
counts['loweralpha'], counts['upperalpha'], counts['mixedalpha'],
Expand Down Expand Up @@ -368,28 +366,25 @@ def plot_password_complexity_piechart(credential_per_account, enabled, system):
title="Key",
loc="upper center",
ncol=2)
try:
yield fig
finally:
plt.close(fig)

return fig


def password_structure_piechart(request, task_id):
credential_per_account, _, _, system, enabled = CredentialStatsView.get_filtered_creds(request)

with plot_password_structure_piechart(credential_per_account, enabled, system) as fig:
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')
fig = plot_password_structure_piechart(credential_per_account, enabled, system)
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')

return response
return response

@contextlib.contextmanager
def plot_password_structure_piechart(credential_per_account, enabled, system):
calculate_char_masks(credential_per_account)

structurecounts = credential_per_account.values("structure").annotate(count=Count("structure")).order_by("count")

fig = plt.figure(figsize=(10, 8))
fig = Figure(figsize=(10, 8))
ax = fig.subplots(2, height_ratios=[3, 1])
piesegments = structurecounts.values_list("count", flat=True)
pielabels = structurecounts.values_list("structure", flat=True)
Expand All @@ -408,24 +403,21 @@ def plot_password_structure_piechart(credential_per_account, enabled, system):
title="Key",
loc="upper center",
ncol=2)
try:
yield fig
finally:
plt.close(fig)

return fig


def password_length_chart(request, task_id):
credential_per_account, _, _, system, enabled = CredentialStatsView.get_filtered_creds(request)

with plot_password_length_chart(credential_per_account, enabled, system) as fig:
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')
fig = plot_password_length_chart(credential_per_account, enabled, system)
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')

return response
return response

@contextlib.contextmanager
def plot_password_length_chart(credential_per_account, enabled, system):
fig = plt.figure(figsize=(7, 8))
fig = Figure(figsize=(7, 8))
ax = fig.subplots(2)

lengths = credential_per_account.annotate(length=Length("secret")).order_by("length").values("length").annotate(
Expand Down Expand Up @@ -461,27 +453,24 @@ def plot_password_length_chart(credential_per_account, enabled, system):
l, y, s in zip(x, y, percents)],
loc="upper center",
ncol=2)
try:
yield fig
finally:
plt.close(fig)

return fig


def password_age_chart(request, task_id):
statsfilter = request.session.get('credentialstatsfilter', {})
system = statsfilter.get("system", None) or None
enabled = statsfilter.get("enabled", False)

with plot_password_age_chart(enabled, system) as fig:
if fig:
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')
fig = plot_password_age_chart(enabled, system)
if fig:
response = HttpResponse(content_type='image/png')
fig.savefig(response, format='png')

return response
else:
return HttpResponseNotFound()
return response
else:
return HttpResponseNotFound()

@contextlib.contextmanager
def plot_password_age_chart(enabled, system):
# Password age
password_ages = []
Expand All @@ -492,7 +481,7 @@ def plot_password_age_chart(enabled, system):
password_ages = session.execute_read(CredentialStatsView._bucket_password_ages, system, enabled)
#TODO merge multiple password_ages from different servers, rather than overwriting
if password_ages:
fig = plt.figure(figsize=(8, 7))
fig = Figure(figsize=(8, 7))
ax = fig.subplots(2)

data = np.array(password_ages)
Expand All @@ -518,12 +507,9 @@ def plot_password_age_chart(enabled, system):
loc="upper center",
ncol=2)

try:
yield fig
finally:
plt.close(fig)
return fig
else:
yield None
return None

class CredentialListView(PermissionRequiredMixin, ListView):
permission_required = 'event_tracker.view_credential'
Expand Down
32 changes: 17 additions & 15 deletions graphical_reports/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from django.http import HttpResponse, HttpResponseNotFound
from django.shortcuts import get_object_or_404
from django.views import View
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.figure import Figure
from matplotlib.ticker import PercentFormatter

from event_tracker.models import Task, AttackTactic, AttackSubTechnique, AttackTechnique
Expand Down Expand Up @@ -51,7 +51,8 @@ def get(self, request, task_id, **kwargs):
colors1 = [f'C{i}' for i in range(len(labels))]

# create a horizontal plot
fig, ax = plt.subplots()
fig = Figure()
ax = fig.subplots()

# Add bands for each weekend day in the range
date_range = task.event_set.all().aggregate(start=Min("timestamp"), end=Max("timestamp"))
Expand All @@ -66,10 +67,10 @@ def get(self, request, task_id, **kwargs):
ax.eventplot(data, colors=colors1, lineoffsets=labels)

ax.grid(axis="y")
plt.xticks(rotation=45/2)
plt.tight_layout()
ax.xaxis.set_tick_params(labelrotation=45/2)
fig.tight_layout()

plt.savefig(response, dpi=300, transparent=True)
fig.savefig(response, dpi=300, transparent=True)
return response


Expand All @@ -85,7 +86,8 @@ def get(self, request, task_id, **kwargs):

exercise_days = task.event_set.dates("timestamp", "day")

fig, ax = plt.subplots()
fig = Figure()
ax = fig.subplots()

labels = []
neither_detected_nor_prevented = []
Expand Down Expand Up @@ -183,9 +185,9 @@ def get(self, request, task_id, **kwargs):
ax.bar(labels, neither_detected_nor_prevented, width, label=f"Neither Detected Nor Prevented {(total_summary['neither_detected_nor_prevented'] / total_summary['total_events']):.2%}", color=badness_colormap(0.0))

ax.legend()
plt.xticks(rotation=45/2)
ax.xaxis.set_tick_params(labelrotation=45/2)

plt.savefig(response, dpi=300, transparent=True)
fig.savefig(response, dpi=300, transparent=True)
return response


Expand All @@ -211,10 +213,10 @@ def get_context_data(self, **kwargs):
order_by=[F("icount").asc(), ]
))

plt = self.generate_heatmap(context["event_tactics"], percentiles, include_subtechniques)
fig = self.generate_heatmap(context["event_tactics"], percentiles, include_subtechniques)

buffer = io.BytesIO()
plt.savefig(buffer, format='png')
fig.savefig(buffer, format='png')
buffer.seek(0)
context['heatmap_b64'] = base64.b64encode(buffer.read()).decode("ASCII")

Expand Down Expand Up @@ -251,10 +253,10 @@ def generate_heatmap(self, tactics, percentiles, include_subtechniques):
# Define main "figure" (i.e. canvas)
figure_height_inches = ((np.sum(subplot_heights) * 3) + len(tactics)) / 3 + 0.9

fig, ax = plt.subplots(ncols=1, nrows=len(tactics) + 2,
figsize=(10, figure_height_inches), height_ratios=subplot_heights)
fig = Figure(figsize=(10, figure_height_inches))
ax = fig.subplots(ncols=1, nrows=len(tactics) + 2, height_ratios=subplot_heights)

plt.colorbar(ScalarMappable(cmap=intensity_colormap), cax=fig.get_axes()[0], orientation="horizontal", format=PercentFormatter(xmax=1))
fig.colorbar(ScalarMappable(cmap=intensity_colormap), cax=fig.get_axes()[0], orientation="horizontal", format=PercentFormatter(xmax=1))
ax[0].set_title("Key: Number of attempts (as percentile)")
ax[1].set_axis_off()

Expand Down Expand Up @@ -309,6 +311,6 @@ def generate_heatmap(self, tactics, percentiles, include_subtechniques):

fig.tight_layout()

plt.subplots_adjust(hspace=0.6)
fig.subplots_adjust(hspace=0.6)

return plt
return fig
4 changes: 4 additions & 0 deletions stepping_stones/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

from pathlib import Path
import matplotlib

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
Expand Down Expand Up @@ -211,3 +212,6 @@
"web-share": [],
"xr-spatial-tracking": [],
}

# Define backend for matplotlib. Ensure a non-interactive backend is chosen to avoid dangling resources
matplotlib.use('agg')

0 comments on commit f92228e

Please sign in to comment.