diff --git a/boost_qa/model.py b/boost_qa/model.py new file mode 100644 index 0000000..2beeae1 --- /dev/null +++ b/boost_qa/model.py @@ -0,0 +1,44 @@ +### Install mPLUG-Owl from https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2#install + +import torch.nn as nn +import torch + +from typing import List +from PIL import Image + +from mplug_owl2.model.builder import load_pretrained_model + +from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + + +class QInstructScorer(nn.Module): + def __init__(self, boost=True, device="cuda:0"): + super().__init__() + tokenizer, model, image_processor, _ = load_pretrained_model("teowu/mplug_owl2_7b_448_qinstruct_preview_v0.2", None, "mplug_owl2", device=device) + prompt = "USER: <|image|>Rate the quality of the image.\nASSISTANT: " + + if not boost: + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["good", "average", "poor"])["input_ids"]] + self.weight_tensor = torch.Tensor([1, 0.5, 0]).half().to(model.device) + else: + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["good", "average", "poor", "high", "medium", "low", "fine", "acceptable", "bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1, 0.5, 0, 1, 0.5, 0, 1, 0.5, 0]).half().to(model.device) / 3. + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def forward(self, image: List[Image.Image]): + with torch.inference_mode(): + image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device) + output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1, self.preferential_ids_] + + return torch.softmax(output_logits, -1) @ self.weight_tensor + + +if __name__ == "__main__": + scorer = QInstructScorer(boost=False) + print(scorer([Image.open("fig/examples_211.jpg"),Image.open("fig/sausage.jpg")]).tolist()) \ No newline at end of file