"""Summarize text, PDF, or HTML files via an LLM and write the results to YAML files."""

import os
import html2text
import json5
import yaml
from cornsnake import (
util_print,
util_file,
util_time,
util_wait,
util_dir,
)
import config
import extractor
import prompts
import util_chat
import util_config

def _clean_response(text):
    """Strip markdown code fences from an LLM response, keeping the longest fenced part."""
prelim_with_data_format = f"```{prompts.get_output_format_name().lower()}"
if prelim_with_data_format in text:
text = text.split(prelim_with_data_format)[1]
delimit = "```"
if delimit in text:
parts = text.split(delimit)
max_len = 0
selected = ""
for part in parts:
if len(part) > max_len:
selected = part
max_len = len(part)
return selected
return text
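
# Example (illustrative, assuming the configured output format is JSON):
#   _clean_response('Sure:\n```json\n{"short_summary": "..."}\n```\nDone.')
# strips the '```json' prefix and returns '\n{"short_summary": "..."}\n',
# the longest part between the ``` delimiters.
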
def _summarize_with_retry(prompt):
    """Send the prompt to the LLM and parse the JSON or YAML response, retrying on failure."""
retries_remaining = config.RETRY_COUNT
rsp_parsed = None
elapsed_seconds = 0
total_cost = 0.0
while not rsp_parsed and retries_remaining > 0:
rsp = None
try:
(rsp, _elapsed_seconds, cost) = util_chat.next_prompt(prompt)
elapsed_seconds += _elapsed_seconds
total_cost += cost
rsp = _clean_response(rsp)
rsp_parsed = None
if util_config.is_json_not_yaml():
rsp_parsed = json5.loads(rsp) # a bit more robust than json package
else:
rsp_parsed = yaml.safe_load(rsp)
except Exception as error:
util_print.print_error("Error parsing response")
util_print.print_error(error)
if config.is_debug:
print("REQ: ", prompt)
print("RSP: ", rsp)
if util_config.is_openai():
util_wait.wait_seconds(config.RETRY_WAIT_SECONDS)
retries_remaining -= 1
if retries_remaining:
util_print.print_warning("Retrying...")
if rsp_parsed is None:
        util_print.print_error("!!! RETRIES EXPIRED !!!")
return (rsp_parsed, elapsed_seconds, total_cost)
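
# Returns (parsed_response, elapsed_seconds, total_cost), where parsed_response
# is None if every retry failed - e.g. (illustrative):
#   ({"short_summary": "...", "long_summary": "...", "paragraphs": [...]}, 3.2, 0.0012)
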
def _divide_into_chunks(items, size):
    # Yield successive chunks of at most `size` items.
    for i in range(0, len(items), size):
        yield items[i : i + size]
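
# Example (illustrative):
#   list(_divide_into_chunks(["a", "b", "c", "d", "e"], 2)) -> [["a", "b"], ["c", "d"], ["e"]]
# The final chunk may be shorter than `size`.
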
def _get_path_to_output_file(path_to_input_file, path_to_output_dir):
if not path_to_output_dir:
return None
input_filename = util_file.get_last_part_of_path(path_to_input_file)
output_filename = util_file.change_extension(
input_filename, ".yaml.txt"
) # adding .txt so can preview in Windows Explorer, Dropbox etc.
path_to_output_file = os.path.join(path_to_output_dir, output_filename)
return path_to_output_file
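
# Example (illustrative, assuming util_file.change_extension swaps the final extension):
#   _get_path_to_output_file("docs/report.pdf", "out") -> "out/report.yaml.txt"
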
def _write_output_file(
    short_summary, long_summary, paragraphs, elapsed_seconds, cost, path_to_output_file
):
    file_result = {
        "short_summary": short_summary,
        "long_summary": long_summary,
        "paragraphs": paragraphs,
        "total_time_seconds": elapsed_seconds,
        "total_estimated_cost_currency": config.OPENAI_COST_CURRENCY,
        "total_estimated_cost": cost,
    }
yaml_text = yaml.dump(file_result)
util_print.print_important(f"Writing YAML file to {path_to_output_file}")
util_file.write_text_to_file(yaml_text, path_to_output_file)
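
# The output YAML contains these keys (yaml.dump sorts them alphabetically):
#   long_summary, paragraphs, short_summary,
#   total_estimated_cost, total_estimated_cost_currency, total_time_seconds
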
def _chunk_text_by_words(input_text):
input_words = input_text.split(" ")
input_words_count = len(input_words)
input_text_list = []
if input_words_count > config.MAIN_INPUT_WORDS:
        util_print.print_warning(
            f"The input file is large: the maximum is {config.MAIN_INPUT_WORDS} words but this file has {input_words_count} words. The text will be split into chunks."
        )
        chunks = _divide_into_chunks(input_words, config.MAIN_INPUT_WORDS)
        for chunk in chunks:
            input_text_list.append(" ".join(chunk))
print(f"Split into {len(input_text_list)} chunks")
else:
input_text_list = [input_text]
return input_text_list
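
# Example (illustrative): with config.MAIN_INPUT_WORDS = 1000, a 2500-word input
# is split into three chunks of roughly 1000, 1000, and 500 words.
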
def _print_file_result(
    short_summary, long_summary, paragraphs, elapsed_seconds, cost, chunk_count, chunks_failed
):
util_print.print_section("FULL Short Summary")
print(short_summary)
util_print.print_section("FULL Long Summary")
print(long_summary)
util_print.print_section("FULL paragraphs Summary")
print("\n".join(paragraphs))
util_print.print_result(
f" -- THIS FILE time: {util_time.describe_elapsed_seconds(elapsed_seconds)}"
)
if cost > 0:
util_print.print_important(
f" -- THIS FILE estimated cost: {config.OPENAI_COST_CURRENCY}{cost}"
)
    if chunks_failed > 0:
        util_print.print_warning(
            f"{chunks_failed} of {chunk_count} document chunks were skipped. If the summary quality is poor, re-run with smaller chunks by reducing MAIN_INPUT_WORDS in config.py."
        )

def _extract_text(path_to_input_file):
input_text = util_file.read_text_from_text_or_pdf_file_skipping_comments(
path_to_input_file
)
if path_to_input_file.endswith(".html"):
return html2text.html2text(input_text)
return input_text

def _convert_array_to_str(a):
    """
    Occasionally the LLM returns a dict or list where a str was expected.
    """
if isinstance(a, str):
return a
if isinstance(a, dict):
util_print.print_warning("Unexpected response format: Converting dict to str")
return yaml.dump(a)
if isinstance(a, list):
util_print.print_warning("Unexpected response format: Converting list to str")
return yaml.dump(a)
return a

def _convert_array_of_dict_to_array(a_list):
if isinstance(a_list, list):
new_list = []
for a1 in a_list:
if isinstance(a1, dict):
new_list.append(yaml.dump(a1))
elif isinstance(a1, list):
new_list += a1
else:
new_list.append(a1)
return new_list
    elif isinstance(a_list, dict):
        return yaml.dump(a_list)
return a_list
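
# Example (illustrative): a mixed response list such as
#   ["plain text", {"point": "detail"}, ["a", "b"]]
# becomes ["plain text", "point: detail\n", "a", "b"] - dicts are dumped to
# YAML and nested lists are flattened one level.
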
def _summarize_one_file(path_to_input_file, target_language, path_to_output_dir):
    """Summarize one file chunk by chunk, print the result, and optionally write a YAML output file."""
util_print.print_section(f"Summarizing '{path_to_input_file}'")
input_text = _extract_text(path_to_input_file)
input_text_chunks = _chunk_text_by_words(input_text)
if target_language is None:
print(f"Summarizing file at '{path_to_input_file}'...")
else:
print(f"Summarizing file at '{path_to_input_file}' into {target_language}...")
short_summary = ""
long_summary = ""
paragraphs = []
elapsed_seconds = 0
cost = 0.0
chunks_failed = 0
chunk_count = 1
for text in input_text_chunks:
prompt = ""
if util_config.is_local_via_ctransformers():
            # TODO try fix
            if target_language is not None:
                raise ValueError(
                    "target_language is only supported when using OpenAI ChatGPT"
                )
            prompt = prompts.get_simple_summarize_prompt(text)
            if config.LOCAL_CTRANSFORMERS_MODEL_TYPE == "llama":
                prompt = prompts.get_llama_summarize_prompt(text)
            (response_plain, _elapsed_seconds, _cost) = _summarize_with_retry(prompt)
            elapsed_seconds += _elapsed_seconds
            cost += _cost
            rsp = {"short_summary": response_plain}
elif util_config.is_local_via_ollama():
if target_language is None:
prompt = prompts.get_ollama_summarize_prompt(text)
else:
prompt = prompts.get_ollama_summary_prompt_and_translate_to(
text, target_language
)
(rsp, _elapsed_seconds, _cost) = _summarize_with_retry(prompt)
elapsed_seconds += _elapsed_seconds
cost += _cost
elif util_config.is_openai():
if target_language is None:
prompt = prompts.get_chatgpt_summarize_prompt(text)
else:
prompt = prompts.get_chatgpt_summary_prompt_and_translate_to(
text, target_language
)
(rsp, _elapsed_seconds, _cost) = _summarize_with_retry(prompt)
elapsed_seconds += _elapsed_seconds
cost += _cost
else:
            raise ValueError(
                "Please check config.py - exactly one of OpenAI, local via ctransformers, or Ollama should be enabled."
            )
util_print.print_section(
f"Short Summary = Chunk {chunk_count} of {len(input_text_chunks)}"
)
if rsp is None:
chunks_failed += 1
else:
if isinstance(rsp, str):
util_print.print_warning("Response is string - expected dict")
print(rsp)
short_summary += rsp + "\n"
else:
if "short_summary" in rsp:
print(rsp["short_summary"])
short_summary += _convert_array_to_str(rsp["short_summary"]) + "\n"
if "long_summary" in rsp:
long_summary += _convert_array_to_str(rsp["long_summary"]) + "\n"
if "paragraphs" in rsp:
paragraphs += _convert_array_of_dict_to_array(rsp["paragraphs"])
chunk_count += 1
    _print_file_result(
        short_summary,
        long_summary,
        paragraphs,
        elapsed_seconds,
        cost,
        len(input_text_chunks),
        chunks_failed,
    )
path_to_output_file = _get_path_to_output_file(
path_to_input_file, path_to_output_dir
)
if path_to_output_file:
_write_output_file(
short_summary,
long_summary,
paragraphs,
elapsed_seconds,
cost,
path_to_output_file,
)
return (elapsed_seconds, cost)

def _print_final_result(files_processed, elapsed_seconds, files_skipped, cost):
util_print.print_section("Completed")
util_print.print_result(
f"{files_processed} files processed in {util_time.describe_elapsed_seconds(elapsed_seconds)}"
)
if files_skipped > 0:
util_print.print_result(f"{files_skipped} files skipped")
if cost > 0:
util_print.print_important(
f" -- Total estimated cost: {config.OPENAI_COST_CURRENCY}{cost}"
)

def summarize_file_or_dir_or_url(
    path_to_input_file_or_dir_or_url, path_to_output_dir, target_language
):
    """Summarize the given file, directory, or URL, skipping inputs whose output already exists."""
if path_to_output_dir:
util_dir.ensure_dir_exists(path_to_output_dir)
input_filepaths = extractor.collect_input_filepaths(
path_to_input_file_or_dir_or_url
)
files_processed = 0
files_skipped = 0
elapsed_seconds = 0
cost = 0
for path_to_input_file in input_filepaths:
path_to_output_file = _get_path_to_output_file(
path_to_input_file, path_to_output_dir
)
if path_to_output_file and os.path.exists(path_to_output_file):
util_print.print_warning(
f"[skipping] output file '{path_to_output_file}' already exists"
)
files_skipped += 1
continue
(_elapsed_seconds, _cost) = _summarize_one_file(
path_to_input_file, target_language, path_to_output_dir
)
elapsed_seconds += _elapsed_seconds
cost += _cost
files_processed += 1
_print_final_result(files_processed, elapsed_seconds, files_skipped, cost)
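
# Usage sketch (illustrative - the paths and language below are hypothetical):
#   summarize_file_or_dir_or_url("docs/report.pdf", "output-summaries", None)
#   summarize_file_or_dir_or_url("docs/", "output-summaries", "French")
# Output files are named <input-name>.yaml.txt under the output directory;
# inputs whose output file already exists are skipped.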