Skip to content

Commit

Permalink
Rearrange evaluations (#9)
Browse files Browse the repository at this point in the history
* standardized output file names, created standalone mintaka-wikidata folder, moved parseCSVFile to own file

* reproduce images in submission

* progress on correctness breakdown plots

* formatted legend strings, fixed order for all plots

* stack bar charts checkpoint

* filled in missing data

* cleaned up and commented

* summed percentages

* tweaked stacked bar chart

* updated stacked bar chart visuals

---------

Co-authored-by: Harry Li <harry.li@ll.mit.edu>
  • Loading branch information
harryli0088 and Harry Li authored Dec 13, 2024
1 parent 486c9bc commit f015768
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 121 deletions.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import fs from "fs"
import papaparse from "papaparse"
import { EvaluationOutputRowType } from "./mintakaEvaluation";
import { parseCSVFile } from "utils/parseCSVFile";


calculateMetrics("./LinkQ Evaluation Output.csv","./Plain LLM Evaluation Output.csv","./output.csv")
calculateMetrics("./data/linkq-evaluation-results.csv","./data/plainllm-evaluation-results.csv","./data/aggregated-evaluation-results.csv")

type MetricType = {
complexityType: string,
Expand Down Expand Up @@ -41,6 +42,7 @@ async function calculateMetrics(
parseCSVFile<EvaluationOutputRowType>(linkqDataPath),
parseCSVFile<EvaluationOutputRowType>(plainLLMDataPath),
])
console.log("linkqData",linkqData)
console.log("Parsed data")
if(linkqData.length !== plainLLMData.length) {
throw new Error(`linkqData and plainLLMData lengths do not match`)
Expand Down Expand Up @@ -182,19 +184,6 @@ function isSyntaxCorrect(row: EvaluationOutputRowType) {
return value === "YES"
}


export function parseCSVFile<T>(path:string):Promise<T[]> {
return new Promise((resolve) => {
const file = fs.createReadStream(path)
papaparse.parse<T>(file, {
header: true,
complete: function(results) {
resolve(results.data)
}
})
})
}

function meanAndStd(numArray: number[]) {
let min = Infinity
let max = -Infinity
Expand Down
1 change: 1 addition & 0 deletions src/utils/evaluations/mintaka-wikidata/data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.csv
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ if (process.env.HTTPS_PROXY) {
import fs from "fs"
import papaparse from "papaparse"

import { ChatGPTAPI } from "../ChatGPTAPI"
import { tryParsingOutQuery } from "../tryParsingOutQuery"
import { runQuery } from "../knowledgeBase/runQuery"
import { summarizeQueryResults } from "../summarizeQueryResults"
import { getEntityDataFromQuery } from "../knowledgeBase/getEntityData"
import { formatSparqlResultsAsString } from "../formatSparqlResultsAsString"
import { ChatGPTAPI } from "../../ChatGPTAPI"
import { tryParsingOutQuery } from "../../tryParsingOutQuery"
import { runQuery } from "../../knowledgeBase/runQuery"
import { summarizeQueryResults } from "../../summarizeQueryResults"
import { getEntityDataFromQuery } from "../../knowledgeBase/getEntityData"
import { formatSparqlResultsAsString } from "../../formatSparqlResultsAsString"
import { QUESTIONS } from "./questions"

Check failure on line 24 in src/utils/evaluations/mintaka-wikidata/mintakaEvaluation.ts

View workflow job for this annotation

GitHub Actions / deploy

Cannot find module './questions' or its corresponding type declarations.

Check failure on line 24 in src/utils/evaluations/mintaka-wikidata/mintakaEvaluation.ts

View workflow job for this annotation

GitHub Actions / deploy

Cannot find module './questions' or its corresponding type declarations.
import { INITIAL_SYSTEM_MESSAGE } from "../knowledgeBase/prompts"
import { queryBuildingWorkflow } from "../queryBuildingWorkflow"
import { INITIAL_SYSTEM_MESSAGE } from "../../knowledgeBase/prompts"
import { queryBuildingWorkflow } from "../../queryBuildingWorkflow"

import { loadEnv } from 'vite'
const ENV = loadEnv("development","../../../")
Expand Down Expand Up @@ -124,7 +124,7 @@ async function runMintakaEvaluation(

export async function runLinkQMintakaEvaluation() {
return await runMintakaEvaluation(
`LinkQ Evaluation Output ${new Date().getTime()}.csv`,
`linkq-evaluation-output-${new Date().getTime()}.csv`,
async (chatGPT:ChatGPTAPI, question:string) => {
//force the LLM to start the query building workflow
chatGPT.messages = [
Expand Down Expand Up @@ -155,7 +155,7 @@ export async function runLinkQMintakaEvaluation() {

export async function runPlainLLMMintakaEvaluation() {
return await runMintakaEvaluation(
`Plain LLM Evaluation Output ${new Date().getTime()}.csv`,
`plainllm-evaluation-results-${new Date().getTime()}.csv`,
async (chatGPT:ChatGPTAPI, question:string) => {
return await chatGPT.sendMessages([
{
Expand Down
File renamed without changes.
16 changes: 16 additions & 0 deletions src/utils/evaluations/mintaka-wikidata/plot/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
1. Download the evaluation results from TODO
2. Place the CSVs from the `../data` folder
3. Rename the CSVs, if applicable:
- 'Evaluation for CHI - Aggregated Results': 'aggregated-evaluation-results.csv'
- 'Evaluation for CHI - Plain LLM Evaluation Output': 'plainllm-evaluation-results.csv'
- 'Evaluation for CHI - LinkQ Evaluation Output': 'linq-evaluation-results.csv'
4. Create a new conda environment, adivate it, and download the requirements
```
conda create --name linkq python=3.12
conda activate linkq
pip install -r requirements.txt
```
5. Run the script to generate the plots
```
python validation_figures.py
```
File renamed without changes.
191 changes: 191 additions & 0 deletions src/utils/evaluations/mintaka-wikidata/plot/validation_figures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT

from pathlib import Path

from functools import reduce
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

sns.set(rc={'figure.dpi': 300, 'savefig.dpi': 300})

ROOT = Path(__file__).parent
DATA = Path(ROOT.parent / 'data')
PLOTS = Path(ROOT / 'plots')

def percent_formatter(x):
return f'{'{:.1%}'.format(x/100)}'

CORRECTNESS_PALETTE = {"LinkQ 0/3": '#999999', "LinkQ 1/3": '#c8ddec', "LinkQ 2/3": '#72aad0', "LinkQ 3/3": '#1f78b4', "GPT-4 0/3": '#999999', "GPT-4 1/3": '#fff4e5', "GPT-4 2/3": '#ffdeb3', "GPT-4 3/3": '#fdbf6f'}
QUESTION_TYPE_ORDER = ['Comparative', 'Yes/No', 'Generic', 'MultiHop', "Intersection"]
PALETTE = {'LinkQ': '#1f78b4', 'GPT-4': '#fdbf6f'}
TO_REPLACE = {'multihop': 'MultiHop', 'generic': 'Generic', 'intersection': 'Intersection', 'yesno': 'Yes/No', 'comparative': 'Comparative'}

def accuracy_barchart_by_category():
# Load the data and rename certain columns and values
df = pd.read_csv(Path(DATA, 'aggregated-evaluation-results.csv'), usecols=['linkqAnswerCorrect', 'plainLLMAnswerCorrect', 'complexityType', 'category', 'id', 'question'])
df = df.rename(columns={'linkqAnswerCorrect': 'LinkQ', 'plainLLMAnswerCorrect': 'GPT-4', 'complexityType': 'Question Type'})
df = df.replace(to_replace=TO_REPLACE)

num_questions_per_type = len(df) // len(df['Question Type'].unique()) # Assumes same number of questions per category
df['LinkQ'] = (df['LinkQ'] > 0).astype(int)
df['GPT-4'] = (df['GPT-4'] > 0).astype(int)

# Unpivot the LinkQ and GPT-4 columns into Algorithm and Correctness columns
df = pd.melt(df, id_vars=['id', 'category', 'Question Type', 'question'], var_name='Algorithm', value_name='Correct')

# Count the correctness values and convert them into percentages
df = df.groupby(['Question Type', 'Algorithm']).agg({'Correct': 'sum'}).sort_values(by='Correct',ascending=False).reset_index()
df['Fraction'] = [f'{v}/{num_questions_per_type}' for v in df['Correct']]
df['% Correct'] = (df['Correct'] / num_questions_per_type) * 100

# Plot the data
ax = sns.barplot(df, x='Question Type', y='% Correct', order=['Comparative', 'Yes/No', 'Generic', 'MultiHop', "Intersection"], hue='Algorithm', hue_order=['LinkQ', 'GPT-4'], palette=PALETTE)

for container in ax.containers:
ax.bar_label(container, fmt=percent_formatter)
plt.savefig(Path(PLOTS, 'accuracy_barchart_by_category.pdf'), bbox_inches='tight', format='pdf')
plt.close()


def timing_boxplot_by_category():
# Load the data and rename certain columns and values
timing_columns = ['Total Seconds', 'id', 'complexityType', 'category']
linkq_df = pd.read_csv(Path(DATA, 'linkq-evaluation-results.csv'), usecols=timing_columns)
linkq_df['Algorithm'] = 'LinkQ'
plainllm_df = pd.read_csv(Path(DATA, 'plainllm-evaluation-results.csv'), usecols=timing_columns)
plainllm_df['Algorithm'] = 'GPT-4'
df = pd.concat([linkq_df, plainllm_df]).reset_index(drop=True)
df = df.rename(columns={'complexityType': 'Question Type'})
df = df.replace(to_replace=TO_REPLACE)

sns.boxplot(df, x='Question Type', y='Total Seconds', order=QUESTION_TYPE_ORDER, hue='Algorithm', palette=PALETTE)
plt.savefig(Path(PLOTS, 'timing_boxplot_by_category.pdf'), bbox_inches='tight', format='pdf')
plt.close()


def correctness_barchart():
# Load the data and rename certain columns and values
df = pd.read_csv(Path(DATA, 'aggregated-evaluation-results.csv'), usecols=['linkqAnswerCorrect', 'plainLLMAnswerCorrect', 'complexityType', 'category', 'id', 'question'])
df = df.rename(columns={'linkqAnswerCorrect': 'LinkQ', 'plainLLMAnswerCorrect': 'GPT-4', 'complexityType': 'Question Type'})
df = df.replace(to_replace=TO_REPLACE)
df['LinkQ'] = df['LinkQ'].apply(lambda x: f'LinkQ {x}/3')
df['GPT-4'] = df['GPT-4'].apply(lambda x: f'GPT-4 {x}/3')

# Assumes same number of questions per category
num_questions_per_type = len(df) // len(df['Question Type'].unique()) # Assumes same number of questions per category

# Unpivot the LinkQ and GPT-4 columns into Algorithm and Correctness columns
df = pd.melt(df, id_vars=['id', 'category', 'Question Type', 'question'], var_name='Algorithm', value_name='Correctness')

# Count the correctness values and convert them into percentages
df['Value'] = 0
df = df.groupby(['Question Type', 'Correctness']).agg(
{'Value': 'count'}).unstack(fill_value=0).stack(future_stack=True).reset_index()
df['Value'] = (df['Value'] / num_questions_per_type) * 100

# Plot the data
ax = sns.barplot(df, x='Question Type', y="Value", order=QUESTION_TYPE_ORDER, hue='Correctness',
hue_order=["LinkQ 3/3","GPT-4 3/3","LinkQ 2/3","GPT-4 2/3","LinkQ 1/3","GPT-4 1/3"],
palette=CORRECTNESS_PALETTE)

for container in ax.containers:
ax.bar_label(container, fmt=percent_formatter)
plt.savefig(Path(PLOTS, f'correctness.pdf'), bbox_inches='tight', format='pdf')
plt.close()


def correctness_stacked_barchart():
# Load the data and rename certain columns and values
df = pd.read_csv(Path(DATA, 'aggregated-evaluation-results.csv'), usecols=['linkqAnswerCorrect', 'plainLLMAnswerCorrect', 'complexityType', 'category', 'id', 'question'])
df = df.rename(columns={'linkqAnswerCorrect': 'LinkQ', 'plainLLMAnswerCorrect': 'GPT-4', 'complexityType': 'Question Type'})
df = df.replace(to_replace=TO_REPLACE)
df['LinkQ'] = df['LinkQ'].apply(lambda x: f'{x}/3')
df['GPT-4'] = df['GPT-4'].apply(lambda x: f'{x}/3')

# Custom sort the question types to keep all the plots consistent
df["Question Type"] = pd.Categorical(df["Question Type"], categories=QUESTION_TYPE_ORDER, ordered=True)
df = df.sort_values("Question Type")

# Assumes same number of questions per category
num_questions_per_type = len(df) // len(df['Question Type'].unique())

# Unpivot the LinkQ and GPT-4 columns into Algorithm and Correctness columns
df = pd.melt(df, id_vars=['id', 'category', 'Question Type', 'question'], var_name='Algorithm', value_name='Correctness')

# Count the correctness values and convert them into percentages
df['Value'] = 0
df = df.groupby(['Question Type', 'Algorithm', 'Correctness'],observed=False).agg(
{'Value': 'count'}).unstack(fill_value=0).stack(future_stack=True).reset_index()
df['Value'] = (df['Value'] / num_questions_per_type) * 100

# Prep the plot data
question_types = df['Question Type'].unique()
x = np.arange(len(question_types)) # X-axis positions for question_types
algorithms = ['LinkQ', 'GPT-4'] # this list determines left to right ordering of the algorithms
correctness = ['3/3','2/3','1/3'] # this list determines bottom to top stacking order of correctness
width = 0.38 # Width of the bar

# Plot side-by-side stacked bars
fig, ax = plt.subplots()
fig.set_figwidth(8)
for alg_idx, algorithm in enumerate(algorithms):
# Filter data for the current algorithm
algorithm_data = df[df['Algorithm'] == algorithm]
# Filter again by correctness
filtered_values = list(map(
lambda x: algorithm_data.loc[algorithm_data['Correctness'] == x]['Value'].reset_index(drop=True),
correctness))

plot_x = x + (alg_idx - 0.5) * width
bottom = np.zeros(len(question_types)) # The first correctness bars will be stacked from the bottom
# Loop over all the correctness to stack the bars on top of each other
for correct_idx, correct in enumerate(correctness):
values = filtered_values[correct_idx] # Series containing the values for this algorithm + correctness, by question type
color = CORRECTNESS_PALETTE[f'{algorithm} {correct}'] # Get the color palette for this algorithm + correctness
# Stack the bars for this correctness
bar = ax.bar(
x=plot_x,
height=values,
width=width,
color=color,
label=f'{algorithm} {correct}',
edgecolor="black",
linewidth=0.5,
bottom=bottom)

# for xpos, value, y in zip(plot_x, values, bottom):
# if value != 0.0:
# ax.text(x=xpos, y=y + value/2, s=percent_formatter(value), ha='center', va='center', fontsize=10)

# For the next set of stacked bars, we need to add these count values so we know where we should stack from
bottom += values

# Label the percentage sums
for xpos, total in zip(plot_x, bottom):
ax.text(x=xpos, y=total + 0.5, s=percent_formatter(total), ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Question Type')
ax.set_ylabel('% Correct')
# ax.set_title('Side-by-Side Stacked Bar Chart')
ax.set_xticks(x)
ax.set_xticklabels(question_types)
ax.legend(title="# Correct / 3 Attempts", title_fontsize=10, bbox_to_anchor=(1, 1), loc='upper left')
plt.grid(axis='x', which='both', visible=False)
plt.tight_layout()
plt.savefig(Path(PLOTS, 'correctness_stacked.pdf'), bbox_inches='tight', format='pdf')
plt.close()

def main():
PLOTS.mkdir(exist_ok=True)
accuracy_barchart_by_category()
timing_boxplot_by_category()
correctness_barchart()
correctness_stacked_barchart()
print("Done creating plots!")


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
//npx tsx prepMintakaQuestions.ts

import fs from "fs"
import papaparse from "papaparse"

import { MintakaQuestionType } from "./mintakaEvaluation";
import { parseCSVFile } from "utils/parseCSVFile";

prepMintakaQuestions()

Expand Down Expand Up @@ -86,15 +86,3 @@ export const QUESTIONS:MintakaQuestionType[] = ${JSON.stringify(filteredQuestion
fs.writeFileSync("./questions.ts",questionsFileContent)
console.log("Done prepping Mintaka questions!")
}

export function parseCSVFile<T>(path:string):Promise<T[]> {
return new Promise((resolve) => {
const file = fs.createReadStream(path)
papaparse.parse<T>(file, {
header: true,
complete: function(results) {
resolve(results.data)
}
})
})
}
12 changes: 0 additions & 12 deletions src/utils/evaluations/plot/README

This file was deleted.

Loading

0 comments on commit f015768

Please sign in to comment.