-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathgoldfish_inference.py
62 lines (56 loc) · 2.61 KB
/
goldfish_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os
import re
import time
import urllib.parse

import gradio as gr

from goldfish_lv import GoldFish_LV
from theme import minigptlv_style
def str2bool(v):
    """Convert a command-line value into a bool, for use as an argparse ``type=``.

    A value that is already a bool is returned unchanged. Strings are compared
    case-insensitively: ``yes/true/t/y/1`` map to True and ``no/false/f/n/0``
    map to False.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither spelling set.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy_spellings = ("yes", "true", "t", "y", "1")
    falsy_spellings = ("no", "false", "f", "n", "0")
    if normalized in truthy_spellings:
        return True
    if normalized in falsy_spellings:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def get_arguments(argv=None):
    """Define and parse the command-line options for this inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as the
            original zero-argument version did. Passing a list allows
            programmatic use and testing without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Inference parameters")
    parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
    # Forwarded as ``number_of_neighbours`` to GoldFish_LV.inference by the script entry point.
    parser.add_argument("--neighbours", type=int, default=3)
    parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
    # Forwarded as ``has_subtitles`` to GoldFish_LV.inference by the script entry point.
    parser.add_argument("--add_subtitles", action='store_true')
    parser.add_argument("--max_new_tokens", type=int, default=512)
    # Takes an explicit value (yes/no, true/false, 1/0, ...) parsed by str2bool, not a bare flag.
    parser.add_argument("--use_openai_embedding", type=str2bool, default=False)
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for short video clips")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=16)
    # NOTE(review): the default is a placeholder, not a real path; pass --video_path explicitly.
    parser.add_argument("--video_path", type=str, default="path for video.mp4", help="Path to the video file or youtube url")
    # Forwarded as the ``instruction`` (question) to GoldFish_LV.inference by the script entry point.
    parser.add_argument("--question", type=str, default="Why rachel is wearing a wedding dress?")
    # NOTE(review): not read directly in this script; presumably consumed by GoldFish_LV,
    # which receives the whole namespace — confirm.
    parser.add_argument("--options", nargs="+")
    return parser.parse_args(argv)
def download_video(youtube_url, *, model=None):
    """Fetch/prepare a video via the model's ``process_video_url`` and return its path.

    Args:
        youtube_url: Video URL (or path) forwarded unchanged to ``process_video_url``.
        model: Object exposing ``process_video_url(url)``. Defaults to the
            module-level ``goldfish_lv`` instance. That name is only bound when
            this file runs as a script, so pass a model explicitly when calling
            this helper from other code.

    Returns:
        Whatever ``model.process_video_url`` returns (used as the processed video path).
    """
    if model is None:
        model = goldfish_lv
    return model.process_video_url(youtube_url)
def process_video(video_path, has_subtitles, instruction="", number_of_neighbours=-1, *, model=None):
    """Run GoldFish_LV inference on a video and return only the prediction.

    Args:
        video_path: Video path; the script entry point passes the value returned
            by ``process_video_url``.
        has_subtitles: Forwarded unchanged to ``inference``.
        instruction: Question/prompt forwarded to ``inference``.
        number_of_neighbours: Forwarded unchanged to ``inference``. The meaning
            of the ``-1`` default is defined by GoldFish_LV — TODO confirm.
        model: Object exposing ``inference(...)`` that returns a mapping with a
            ``"pred"`` key. Defaults to the module-level ``goldfish_lv`` instance,
            which is only bound when this file runs as a script.

    Returns:
        The ``"pred"`` entry of the inference result.
    """
    if model is None:
        model = goldfish_lv
    result = model.inference(video_path, has_subtitles, instruction, number_of_neighbours)
    return result["pred"]
def return_video_path(youtube_url):
    """Map a YouTube URL to its local cache path ``workspace/tmp/<video_id>.mp4``.

    Accepted inputs:
      * ``http(s)://[www.|m.]youtube.com/watch?...v=<id>...`` (``v`` may appear
        anywhere in the query string),
      * ``http(s)://youtu.be/<id>[?...]``,
      * a bare video id with no scheme/host. The previous split-based parser
        passed these through unchanged, so they are still accepted.

    The id must consist only of ``[A-Za-z0-9_-]``. Because it becomes part of a
    filesystem path, this also rejects ``/`` and ``..``.

    Raises:
        ValueError: if no valid video id can be extracted. Previously almost any
            non-matching URL was returned whole as a bogus "id"; it now raises.
    """
    url = youtube_url.strip()
    parts = urllib.parse.urlsplit(url)
    host = parts.hostname or ""  # hostname is already lower-cased and port-free
    if host in ("youtube.com", "www.youtube.com", "m.youtube.com") and parts.path == "/watch":
        # parse_qs drops blank values, so "watch?v=" yields the "" fallback.
        video_id = urllib.parse.parse_qs(parts.query).get("v", [""])[0]
    elif host == "youtu.be":
        video_id = parts.path.lstrip("/").split("/")[0]
    elif not parts.scheme and not parts.netloc:
        video_id = url
    else:
        video_id = ""
    if not re.fullmatch(r"[A-Za-z0-9_-]+", video_id):
        raise ValueError("Invalid YouTube URL provided.")
    return os.path.join("workspace", "tmp", f"{video_id}.mp4")
# NOTE(review): this parses sys.argv at import time, so importing this module from
# another program parses that program's arguments. It is kept at module level because
# other code may read ``args`` from this module — consider moving it under the guard.
args = get_arguments()
if __name__ == "__main__":
    # perf_counter is monotonic, so measured durations are unaffected by
    # wall-clock adjustments (unlike time.time()).
    load_start = time.perf_counter()
    print("using openai: ", args.use_openai_embedding)
    goldfish_lv = GoldFish_LV(args)
    load_end = time.perf_counter()
    print("Time taken to load model: ", load_end - load_start)
    processed_video_path = download_video(args.video_path)
    pred = process_video(processed_video_path, args.add_subtitles, args.question, args.neighbours)
    print("Question answer: ", pred)
    # Measured from the end of model loading, so this also includes fetching/
    # preprocessing the video in download_video, not just model inference.
    print("Time taken for inference: ", time.perf_counter() - load_end)