Skip to content

Commit

Permalink
iqa_model
Browse files Browse the repository at this point in the history
  • Loading branch information
haoning.wu committed Dec 13, 2023
1 parent 4a163b8 commit c3627ab
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions boost_qa/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
### Install mPLUG-Owl from https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2#install

import torch.nn as nn
import torch

from typing import List
from PIL import Image

from mplug_owl2.model.builder import load_pretrained_model

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


class QInstructScorer(nn.Module):
    """Score images with the Q-Instruct mPLUG-Owl2 preview checkpoint.

    Each image is presented to the model together with a fixed rating prompt.
    The logits at the last sequence position are read for a small set of
    rating words, softmax-normalised among those words only, and reduced to
    one scalar per image by a dot product with per-word weights.

    Args:
        boost: If True, nine rating words are used (three words per level)
            and the level weights are divided by 3. If False, only
            "good" / "average" / "poor" are used with weights 1 / 0.5 / 0.
        device: Device string forwarded to ``load_pretrained_model``.
    """

    # Rating words, one tuple per synonym group, each ordered top -> bottom.
    # Only the first group is used when ``boost`` is False.
    _LEVEL_WORDS = (
        ("good", "average", "poor"),
        ("high", "medium", "low"),
        ("fine", "acceptable", "bad"),
    )
    # Weight given to the (top, middle, bottom) word of every group.
    _LEVEL_WEIGHTS = (1, 0.5, 0)

    def __init__(self, boost=True, device="cuda:0"):
        super().__init__()
        tokenizer, model, image_processor, _ = load_pretrained_model(
            "teowu/mplug_owl2_7b_448_qinstruct_preview_v0.2",
            None,
            "mplug_owl2",
            device=device,
        )
        prompt = "USER: <|image|>Rate the quality of the image.\nASSISTANT: "

        groups = self._LEVEL_WORDS if boost else self._LEVEL_WORDS[:1]
        words = [word for group in groups for word in group]

        # Index 1 of every encoding is kept: this skips the first token
        # (presumably a BOS token) and keeps a single token per word, so a
        # word split into several tokens is represented by its first piece
        # only -- TODO confirm against the checkpoint's tokenizer.
        self.preferential_ids_ = [ids[1] for ids in tokenizer(words)["input_ids"]]

        weights = torch.Tensor(list(self._LEVEL_WEIGHTS) * len(groups)).half().to(model.device)
        # Divide by the number of groups (3 with boost, 1 without).
        # NOTE(review): the softmax in forward() already spans all selected
        # words, so with boost this division caps the score at 1/3. Kept
        # unchanged to preserve existing outputs -- confirm it is intended.
        weights = weights / len(groups)

        prompt_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        # Registered as non-persistent buffers so that .to()/.cuda() on the
        # scorer moves them together with the model (a registered submodule)
        # while state_dict() stays exactly as before.
        self.register_buffer("weight_tensor", weights, persistent=False)
        self.register_buffer("input_ids", prompt_ids, persistent=False)

    def forward(self, image: List[Image.Image]):
        """Return one score per input image.

        Args:
            image: Images handed to the image processor's ``preprocess``.

        Returns:
            A tensor holding, for every entry along the first axis of the
            preprocessed pixel tensor, the weighted sum of the softmax over
            the rating-word logits.
        """
        with torch.inference_mode():
            pixel_values = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
            image_tensor = pixel_values.half().to(self.model.device)

            # The same prompt is used for every image in the batch.
            batch_input_ids = self.input_ids.repeat(image_tensor.shape[0], 1)
            logits = self.model(batch_input_ids, images=image_tensor)["logits"]

            # Keep only the last position and the rating-word vocabulary ids.
            level_logits = logits[:, -1, self.preferential_ids_]
            return torch.softmax(level_logits, -1) @ self.weight_tensor


if __name__ == "__main__":
    # Demo run: build the scorer with the three-word (non-boost) vocabulary,
    # then score two example images under fig/ and print the scores as a list.
    scorer = QInstructScorer(boost=False)
    demo_paths = ("fig/examples_211.jpg", "fig/sausage.jpg")
    demo_images = [Image.open(path) for path in demo_paths]
    demo_scores = scorer(demo_images)
    print(demo_scores.tolist())

0 comments on commit c3627ab

Please sign in to comment.