# -*- coding: utf-8 -*-
"""cacn

Automatically generated by Colaboratory.

Author: Megha Sundriyal

Original file is located at
    https://colab.research.google.com/drive/1x0CL7qbcfBiBiQ9jGBP85c3Y45UOQDeX
"""
import pandas as pd
import openai
import logging
import csv
import re
import nltk

nltk.download('punkt')


class CACN:
    def __init__(self, api_key, prompt_file, data_file, output_file):
        self.api_key = api_key
        self.prompt = self.get_in_context_examples(prompt_file)
        self.data = self.get_data(data_file)
        self.output_file = output_file

    def get_in_context_examples(self, prompt_file):
        """
        Read the in-context examples from the given file.

        Args: file path
        Returns: cleaned in-context examples to append to the prompt.
        """
        # Read the prompt from the text file and save as prompt
        with open(prompt_file, 'r') as file:
            prompt = file.read()
        prompt = re.sub(r'[^a-zA-Z0-9.:?\s]', '', prompt)
        prompt = prompt.replace('\t', ' ').replace('\n', ' ')
        return prompt

    def decontract(self, text):
        """
        Decontract the contracted words in the given text.

        Args:
            text (str): Text containing contracted words.
        Returns:
            str: Text with contracted words expanded.
        Example:
            >>> decontract("I can't go. It's raining.")
            "I cannot go. It is raining."
        """
        contractions_dict = {}
        with open('replacements.txt', 'r') as f:
            for line in f:
                key, value = line.strip().split(': ')
                contractions_dict[key] = value
        for contraction, expansion in contractions_dict.items():
            pattern = re.compile(contraction, re.IGNORECASE)
            text = re.sub(pattern, expansion, text)
        return text
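
    # Note: replacements.txt is not part of this listing; based on the
    # split(': ') call in decontract(), it is assumed to contain one
    # "contraction: expansion" pair per line, for example:
    #   can't: cannot
    #   it's: it is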

    def clean_post(self, post):
        # remove links (match both http and https URLs)
        post = re.sub(r"https?:\S+", "", post)
        # remove special characters
        post = re.sub(r"[\(\)#@!\^\\\/\+><]", "", post)
        # remove extra white spaces
        post = re.sub(r"\s+", " ", post)
        # lower case
        post = post.lower()
        return post

    def get_data(self, data_file):
        """
        Read the data and return the preprocessed posts.

        Args: file path
        Returns: DataFrame with an added 'clean post' column.
        """
        df = pd.read_csv(data_file)
        # preprocess posts
        df['clean post'] = df['Social Media Post'].apply(self.decontract)
        df['clean post'] = df['clean post'].apply(self.clean_post)
        return df

    def extract_claim(self, sentence):
        """
        Extract the normalized claim from the response generated by the model.

        Args:
            sentence (str): The sentence generated by GPT.
        Returns:
            str: Normalized claim.
        Example:
            >>> extract_claim("The post claims that Thailand will ban Pfizer vaccines after a Thai princess falls into a coma following a booster jab. This claim is verifiable through various reports and has a huge social impact. Thus, the central claim here is Thailand will ban Pfizer vaccines after a Thai princess falls into a coma following a booster jab.")
            "Thailand will ban Pfizer vaccines after a Thai princess falls into a coma following a booster jab."
        """
        # Normalize common abbreviations so the sentence tokenizer does not split on them.
        sentence = sentence.replace("U.S.", "US")
        sentence = sentence.replace("Dr.", "Dr")
        sentence = sentence.replace("Ph.D.", "PhD")
        sentences = nltk.sent_tokenize(sentence)
        # The last sentence of the response contains the central claim.
        claim = sentences[-1].strip()
        pattern = r'(the central claim is|the normalized checkworthy claim is|the crucial checkworthy claim is|claim should be|claim here is|normalized checkworthy claim should be|central checkworthy claim is|normalized checkworthy claim should be:|he central claim in the post is|to be fact-checked here is)(?: that)?(.*?)[\.\n]'
        matches = re.search(pattern, claim, flags=re.IGNORECASE)
        if matches:
            claim = matches.group(2).strip()
            claim = claim.replace(":", "")
            return claim
        return claim

    def generate_claims(self):
        openai.api_key = self.api_key
        SUMM_MAX_LENGTH = 120
        MAX_TOKEN_LIMIT = 4096

        with open(self.output_file, 'a', newline='') as csvfile:
            writer = csv.writer(csvfile)
            # write the header row
            writer.writerow(['Social Media post', 'CACN Normalized Claim', 'Gold Normalized Claim'])
            for instance in range(len(self.data)):
                text = self.data['clean post'].iloc[instance]
                prompt_text = f"{self.prompt} Identify the central claim in the given post: {text} \n Let's think step by step."
                # crude length guard: truncate the post if the prompt grows too long
                # (note this compares characters, not tokens)
                if len(prompt_text) > MAX_TOKEN_LIMIT:
                    text = text[:MAX_TOKEN_LIMIT]
                    prompt_text = f"{self.prompt} Identify the central claim in the given post: {text} \n Let's think step by step."
                response = openai.Completion.create(
                    engine="text-davinci-003",
                    prompt=prompt_text,
                    temperature=0.6,
                    max_tokens=SUMM_MAX_LENGTH,
                    top_p=1,
                    frequency_penalty=0.1,
                    stop=None
                )
                gpt_summary = response["choices"][0]["text"].strip()
                normalized_claim = self.extract_claim(gpt_summary)
                row = [self.data.iloc[instance]['Social Media Post'], normalized_claim, self.data.iloc[instance]['Normalized Claim']]
                print(normalized_claim)
                writer.writerow(row)
                print(row)


if __name__ == "__main__":
    api_key = "<add-your-key-here>"  # placeholder: supply your OpenAI API key
    prompt_file = 'prompt.txt'
    data_file = 'CLAN-samples.csv'
    output_file = 'output.csv'

    claim_extractor = CACN(api_key, prompt_file, data_file, output_file)
    claim_extractor.generate_claims()
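
# Usage sketch (assumptions, not part of the original script): prompt.txt holds the
# in-context examples prepended to every query, replacements.txt holds the
# "contraction: expansion" pairs read by decontract(), and CLAN-samples.csv is
# expected to provide at least the 'Social Media Post' and 'Normalized Claim'
# columns. Each run appends a header row plus one result row per post to output.csv.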