import pandas as pd
from tqdm import tqdm
import argparse
import os
from DrugGPT.src.ensemble.ensemble_model import EnsembleModel
from DrugGPT.src.prompt.prompt_manager import PromptManager
from DrugGPT.src.llama.llama_utils import LLaMAUtils, SoftEmbedding
from DrugGPT.src.gcn.dsdg import DSDGGenerator
from DrugGPT.src.utils.parser import binary_parser, text_parser, mc_parser
from DrugGPT.src.gcn.gcn_model import GraphConvolutionalNetwork
from DrugGPT.src.prompt_tuning.soft_prompt_tuning import SoftPromptTuner, extract_entities
import logging
import yaml


class Evaluation:
    """
    A class for evaluating the performance of the EnsembleModel.

    Attributes:
        ensemble_model (EnsembleModel): The ensemble model to be evaluated.
        parser_dict (dict): A dictionary of parsing functions for different datasets.
        log_results (bool): Flag to enable logging of results.
        store_results (bool): Flag to enable storing of results in a file.
        log_wrong_answers_only (bool): Flag to log only the wrong answers.
        use_openai (bool): Flag to use the OpenAI API for inference.

    Methods:
        check_text_accuracy(prediction, actual): Checks if the prediction is accurate compared to the actual answer.
        calculate_f1_metrics(prediction, label): Calculates F1 metrics for a given prediction and label.
        log_answer(i, input_data, prediction, actual_label): Logs an answer with its index, input, prediction, and label.
        evaluate(dataset_name, evaluation_set): Evaluates the model on a given dataset and returns the results.
    """

    def __init__(self, ensemble_model, parser_dict, log_results=True, store_results=False,
                 log_wrong_answers_only=False, use_openai=False):
        self.ensemble_model = ensemble_model
        self.parser_dict = parser_dict
        self.log_results = log_results
        self.store_results = store_results
        self.log_wrong_answers_only = log_wrong_answers_only
        self.use_openai = use_openai

    @staticmethod
    def check_text_accuracy(prediction, actual):
        # A prediction counts as correct if every word of the gold answer appears in it,
        # or if the initials of the predicted words spell out the gold answer (acronym match).
        prediction = prediction.rstrip('.')
        prediction_initials = ''.join(word[0] for word in prediction.split())
        actual_words = actual.split()
        return all(word in prediction for word in actual_words) or prediction_initials == actual
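
    # Worked example (illustrative, not part of the original file): for
    # prediction = "chronic obstructive pulmonary disease." and actual = "copd",
    # the trailing period is stripped, the word-containment check fails, but the
    # initials "copd" equal the gold answer, so check_text_accuracy returns True.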

    @staticmethod
    def calculate_f1_metrics(prediction, label):
        # Treat the prediction and label as comma-separated sets of items and
        # compute precision, recall, and F1 over the set overlap.
        prediction_set = set(prediction.split(', '))
        label_set = set(label.split(', '))
        true_positives = len(prediction_set & label_set)
        false_positives = len(prediction_set - label_set)
        false_negatives = len(label_set - prediction_set)
        precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
        return precision, recall, f1_score
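
    # Worked example (illustrative): prediction = "aspirin, ibuprofen" against
    # label = "aspirin, acetaminophen" yields one true positive ("aspirin"),
    # one false positive, and one false negative, so precision = recall = 0.5
    # and F1 = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5.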

    def log_answer(self, i, input_data, prediction, actual_label):
        # Honor the logging flags: skip when result logging is disabled, and
        # optionally log only the answers the model got wrong.
        if not self.log_results:
            return
        if self.log_wrong_answers_only and prediction.lower() == actual_label.lower():
            return
        logging.info(f"Index: {i} Question: {input_data}")
        logging.info(f"Predicted: {prediction}, Actual: {actual_label}")

    def evaluate(self, dataset_name, evaluation_set):
        print(f"\nEvaluating: {dataset_name}")
        slice_size = len(evaluation_set['sample'])
        accurate_predictions = 0
        precision_list = []
        recall_list = []
        f1_list = []
        wrong_answers = []
        for i in tqdm(range(slice_size), desc="Processing"):
            input_data = evaluation_set['sample'][i]
            # Run the ensemble model and parse its raw output with the dataset-specific parser.
            full_response = self.ensemble_model.run_inference(input_data, use_openai=self.use_openai)
            parsed_response = self.parser_dict[dataset_name](full_response)
            correct_answer = self.check_text_accuracy(parsed_response, evaluation_set['label'][i].lower())
            if correct_answer:
                accurate_predictions += 1
            else:
                wrong_answers.append((i, input_data, parsed_response, evaluation_set['label'][i]))
            # For chatDoctor, also track precision/recall/F1 over the comma-separated answer items.
            if dataset_name == 'chatDoctor':
                precision, recall, f1 = self.calculate_f1_metrics(parsed_response, evaluation_set['label'][i].lower())
                precision_list.append(precision)
                recall_list.append(recall)
                f1_list.append(f1)
            self.log_answer(i, input_data, parsed_response, evaluation_set['label'][i])
        results = {
            'Accuracy': accurate_predictions / slice_size
        }
        if dataset_name == 'chatDoctor':
            results['Average Precision'] = sum(precision_list) / len(precision_list)
            results['Average Recall'] = sum(recall_list) / len(recall_list)
            results['Average F1 Score'] = sum(f1_list) / len(f1_list)
        if self.store_results:
            df = pd.DataFrame(wrong_answers, columns=['Index', 'Question', 'Predicted', 'Actual'])
            df.to_csv(f'evaluation_wrong_answers_{dataset_name}.csv', index=False)
        return results


def load_evaluation_set(dataset_name):
    """
    Load the evaluation set for a given dataset.

    Args:
        dataset_name (str): Name of the dataset to load.

    Returns:
        DataFrame: The loaded dataset.
    """
    # Dictionary mapping dataset names to their file paths
    data_paths_dict = {
        'pubmedqa': {
            'type': 'binary',
            'data': '../../data/pubmedqa_data.csv',
            'answer': '../../data/pubmedqa_answer.csv'
        },
        'ade': {
            'type': 'text',
            'data': '../../data/ade_data.csv',
            'answer': '../../data/ade_answer.csv'
        },
        'chatDoctor': {
            'type': 'text',
            'data': '../../data/chatDoctor_data.csv',
            'answer': '../../data/chatDoctor_answer.csv'
        },
        'DDI_binary': {
            'type': 'binary',
            'data': '../../data/DDI_binary_data.csv',
            'answer': '../../data/DDI_binary_answer.csv'
        },
        'drug_usage': {
            'type': 'text',
            'data': '../../data/drug_usage_data.csv',
            'answer': '../../data/drug_usage_answer.csv'
        },
        'medmcqa': {
            'type': 'mc',
            'data': '../../data/medmcqa_data.csv',
            'answer': '../../data/medmcqa_answer.csv'
        },
        'mmlu_mc': {
            'type': 'mc',
            'data': '../../data/mmlu_mc_data.csv',
            'answer': '../../data/mmlu_mc_answer.csv'
        },
        'usmle_mc': {
            'type': 'mc',
            'data': '../../data/usmle_mc_data.csv',
            'answer': '../../data/usmle_mc_answer.csv'
        },
        'moderna_interactions': {
            'type': 'binary',
            'data': '../../data/moderna_interactions_data.csv',
            'answer': '../../data/moderna_interactions_answer.csv'
        }
    }

    # Get the paths for the specified dataset
    dataset_paths = data_paths_dict.get(dataset_name, {})
    if not dataset_paths:
        raise ValueError(f"Dataset {dataset_name} not found in data paths dictionary")

    # Load the samples and their answers
    data_df = pd.read_csv(dataset_paths['data'])
    answer_df = pd.read_csv(dataset_paths['answer'])

    # Merge data and answers into a single DataFrame. 'some_common_column' is a
    # placeholder for the shared key column; the merged frame must expose the
    # 'sample' and 'label' columns that Evaluation.evaluate reads.
    evaluation_set = pd.merge(data_df, answer_df, on='some_common_column')
    return evaluation_set
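
# Illustrative usage sketch (not part of the original script; assumes an
# EnsembleModel instance named `ensemble_model` has already been constructed):
#
#     eval_set = load_evaluation_set('pubmedqa')
#     evaluator = Evaluation(ensemble_model, {'pubmedqa': binary_parser})
#     print(evaluator.evaluate('pubmedqa', eval_set))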


def main(args):
    # Configure logging so the logging.info calls below are actually emitted.
    logging.basicConfig(level=logging.INFO)

    # Load configurations from model.yaml
    with open('model.yaml', 'r') as file:
        configs = yaml.safe_load(file)
    LLAMA_CONFIGS = configs['LLAMA_CONFIGS']
    GCN_CONFIGS = configs['GCN_CONFIGS']

    # Initialize components
    prompt_manager = PromptManager()
    dsdg_generator = DSDGGenerator(args.excel_path, embd_model_name='all-MiniLM-L6-v2')
    llama_utils = LLaMAUtils(LLAMA_CONFIGS)
    gcn_model = GraphConvolutionalNetwork(GCN_CONFIGS['input_dim'], GCN_CONFIGS['hidden_dim'],
                                          GCN_CONFIGS['output_dim'])

    # Initialize the ensemble model with the GCN as soft prompt generator
    ensemble_model = EnsembleModel(prompt_manager, gcn_model, dsdg_generator, llama_utils, args.openai_key)

    # Define parser dictionary for different datasets
    parser_dict = {
        'pubmedqa': binary_parser,
        'ade': text_parser,
        'chatDoctor': text_parser,
        'DDI_binary': binary_parser,
        'drug_usage': text_parser,
        'medmcqa': mc_parser,
        'mmlu_mc': mc_parser,
        'usmle_mc': mc_parser,
        'moderna_interactions': binary_parser
    }

    # Load the evaluation dataset (load_evaluation_set is keyed by dataset name, not file path)
    evaluation_set = load_evaluation_set(args.dataset_name)

    # Initialize and run evaluation
    evaluator = Evaluation(ensemble_model, parser_dict, args.log_results, args.store_results,
                           args.log_wrong_answers_only, args.use_open_ai)
    results = evaluator.evaluate(args.dataset_name, evaluation_set)

    # Log and store results
    logging.info(f"Evaluation results: {results}")
    if args.store_results:
        with open(f"evaluation_results_{args.dataset_name}.txt", 'w') as file:
            file.write(str(results))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate Ensemble Model")
    parser.add_argument('--openai_key', type=str, required=True, help='OpenAI API key')
    parser.add_argument('--hf_key', type=str, required=True, help='Hugging Face API key')
    parser.add_argument('--excel_path', type=str, required=True, help='Path to DSDG Excel file')
    parser.add_argument('--dataset_name', type=str, required=True,
                        choices=['pubmedqa', 'ade', 'chatDoctor', 'DDI_binary', 'drug_usage', 'medmcqa', 'mmlu_mc',
                                 'usmle_mc', 'moderna_interactions'],
                        help='Name of the dataset for evaluation')
    parser.add_argument('--evaluation_set_path', type=str, required=True, help='Path to the evaluation dataset')
    parser.add_argument('--log_results', action='store_true', help='Enable logging of results')
    parser.add_argument('--store_results', action='store_true', help='Enable storing of results')
    parser.add_argument('--log_wrong_answers_only', action='store_true', help='Log only wrong answers')
    parser.add_argument('--use_open_ai', action='store_true',
                        help='Use the OpenAI API for inference with the latest-generation model for better '
                             'conversational alignment')
    args = parser.parse_args()

    # Set environment variables for API keys
    os.environ["OPENAI_API_KEY"] = args.openai_key
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = args.hf_key

    main(args)