Fix aesthetic score (again), add llava reward
This commit is contained in:
@@ -30,22 +30,22 @@ class MLP(nn.Module):
|
||||
|
||||
|
||||
class AestheticScorer(torch.nn.Module):
    """Scores images by feeding CLIP ViT-L/14 image features through an MLP head.

    The MLP weights are loaded from ``sac+logos+ava1-l14-linearMSE.pth`` under
    ``ASSETS_PATH``. The module is put in eval mode on construction.

    Args:
        dtype: torch dtype that the preprocessed inputs are cast to before being
            passed to CLIP. NOTE(review): the module's own parameters are not
            cast here; presumably the caller keeps them in a matching dtype.
    """

    def __init__(self, dtype):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.mlp = MLP()
        # Load onto CPU so the checkpoint can be read regardless of the device it
        # was saved from; the caller moves the whole module afterwards.
        state_dict = torch.load(
            ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"),
            map_location="cpu",
        )
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return the MLP output for ``images`` with dimension 1 squeezed out.

        ``images`` is handed directly to the CLIP processor, so it may be anything
        that processor accepts. Inputs are moved to whichever device the module's
        parameters live on rather than unconditionally to CUDA.
        """
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        # presumably the MLP emits shape (batch, 1); squeeze(1) yields (batch,)
        return self.mlp(embed).squeeze(1)
|
||||
|
@@ -32,14 +32,145 @@ def jpeg_compressibility():
|
||||
def aesthetic_score():
    """Build a reward function that scores images with ``AestheticScorer``.

    The scorer is constructed once (float32, moved to CUDA) and captured by the
    returned closure.

    Returns:
        A function ``_fn(images, prompts, metadata)`` returning ``(scores, {})``.
        ``prompts`` and ``metadata`` are accepted but not used.
    """
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # presumably float images in [0, 1]; rescale to uint8 pixel values
            # and keep them as a tensor for the scorer's CLIP processor
            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn
|
||||
|
||||
|
||||
def llava_strict_satisfaction():
    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    Returns:
        A function ``_fn(images, prompts, metadata)`` returning ``(scores, info)``, where ``scores`` is a numpy array
        holding, per image, the fraction of answers found as substrings of the corresponding LLaVA responses, and
        ``info["answers"]`` is a numpy array of the raw server outputs. ``prompts`` is ignored.

    Raises (from ``_fn``):
        requests.HTTPError: if the server replies with an error status that is not retried.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 4
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    # retry (with backoff) on HTTP 500 for any method, e.g. while the server is busy
    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_info = {
            "answers": [],
        }

        # np.array_split rejects zero sections, so an empty input is handled by
        # simply skipping the request loop.
        num_batches = int(np.ceil(len(images) / batch_size))
        if num_batches > 0:
            images_batched = np.array_split(images, num_batches)
            metadata_batched = np.array_split(metadata, num_batches)
        else:
            images_batched, metadata_batched = [], []

        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                jpeg_images.append(buffer.getvalue())

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [m["questions"] for m in metadata_batch],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)
            # fail loudly on error statuses instead of trying to unpickle an error page
            response.raise_for_status()

            # SECURITY NOTE: pickle.loads executes arbitrary code from the payload;
            # this is only acceptable because the server is a trusted local process.
            response_data = pickle.loads(response.content)

            # correct[i][j] is True when the j-th answer of image i appears in the j-th response
            correct = np.array(
                [
                    [ans in resp for ans, resp in zip(m["answers"], responses)]
                    for m, responses in zip(metadata_batch, response_data["outputs"])
                ]
            )
            scores = correct.mean(axis=-1)

            all_scores += scores.tolist()
            all_info["answers"] += response_data["outputs"]

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn
|
||||
|
||||
|
||||
def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    Returns:
        A function ``_fn(images, prompts, metadata)`` returning ``(scores, info)``, where ``scores`` is a numpy array
        of the server-reported recall values and ``info`` holds numpy arrays under "precision", "f1" and "outputs".
        ``metadata`` is ignored.

    Raises (from ``_fn``):
        requests.HTTPError: if the server replies with an error status that is not retried.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 16
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    # retry (with backoff) on HTTP 500 for any method, e.g. while the server is busy
    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _as_list(values):
        # squeeze() collapses a single-image batch to a 0-d array whose tolist()
        # is a bare scalar; atleast_1d keeps it a one-element list so that
        # `+=` extends the accumulator by exactly one entry.
        return np.atleast_1d(np.array(values).squeeze()).tolist()

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_info = {
            "precision": [],
            "f1": [],
            "outputs": [],
        }

        # np.array_split rejects zero sections, so an empty input is handled by
        # simply skipping the request loop.
        num_batches = int(np.ceil(len(images) / batch_size))
        if num_batches > 0:
            images_batched = np.array_split(images, num_batches)
            prompts_batched = np.array_split(prompts, num_batches)
        else:
            images_batched, prompts_batched = [], []

        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                jpeg_images.append(buffer.getvalue())

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)
            # fail loudly on error statuses instead of trying to unpickle an error page
            response.raise_for_status()

            # SECURITY NOTE: pickle.loads executes arbitrary code from the payload;
            # this is only acceptable because the server is a trusted local process.
            response_data = pickle.loads(response.content)

            # use the recall score as the reward
            all_scores += _as_list(response_data["recall"])

            # save the precision and f1 scores for analysis
            all_info["precision"] += _as_list(response_data["precision"])
            all_info["f1"] += _as_list(response_data["f1"])
            all_info["outputs"] += _as_list(response_data["outputs"])

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn
|
||||
|
Reference in New Issue
Block a user