VOOZH about

URL: https://huggingface.co/kye/GigaBind

⇱ kye/GigaBind · Hugging Face


GigaBind

An ImageBind model fine-tuned with LoRA for images, audio, and many other modalities.

Usage

import logging
import torch
import data

from models import imagebind_model
from models.imagebind_model import ModalityType, load_module
from models import lora as LoRA

logging.basicConfig(level=logging.INFO, force=True)

# --- Run configuration ---------------------------------------------------------------
# Patch LoRA adapters into the text and vision trunks and load the fine-tuned LoRA weights.
lora = True
# Load only the linear-probing heads instead of LoRA weights. Mutually exclusive with `lora`.
linear_probing = False
# Inference device. To use a GPU, set it to
# `"cuda:0" if torch.cuda.is_available() else "cpu"`.
device = "cpu"
# With `lora`, also load the postprocessors and heads saved next to the LoRA weights.
load_head_post_proc_finetuned = True

# Validate with an explicit exception: `assert` statements are removed under `python -O`,
# which would silently allow an invalid combination of flags.
if linear_probing and lora:
    raise ValueError(
        "Linear probing is a subset of LoRA training procedure for ImageBind. "
        "Cannot set both linear_probing=True and lora=True. "
    )

# Multiplier applied to the text-similarity logits before the softmax.
if lora and not load_head_post_proc_finetuned:
    # HACK: the fine-tuned postprocessors/heads are not loaded, so the normalisation they would
    # provide is missing. Compensate by scaling with
    # `max batch size used during training / temperature`.
    lora_factor = 12 / 0.07
else:
    # All params are assumed to be loaded properly, so no extra scaling is applied. In the LoRA
    # case this still shifts the output distribution away from the original model's.
    lora_factor = 1

# Text prompts to compare against the images and audio clips below. The names mirror the
# image file names (e.g. "dog3" <-> dog3.jpg).
text_list = [
    "bird",
    "car",
    "dog3",
    "dog5",
    "dog8",
    "grey_sloth_plushie",
]

# Image inputs, relative to the working directory.
image_paths = [
    ".assets/bird_image.jpg",
    ".assets/car_image.jpg",
    ".assets/dog3.jpg",
    ".assets/dog5.jpg",
    ".assets/dog8.jpg",
    ".assets/grey_sloth_plushie.jpg",
]

# Audio inputs, relative to the working directory.
audio_paths = [
    ".assets/bird_audio.wav",
    ".assets/car_audio.wav",
    ".assets/dog_audio.wav",
]

# Instantiate the pretrained ImageBind-Huge model, then optionally patch in fine-tuned weights.
model = imagebind_model.imagebind_huge(pretrained=True)

# Checkpoint locations shared by the load calls below (previously repeated inline).
lora_checkpoint_dir = ".checkpoints/lora/550_epochs_lora"
linear_probing_checkpoint_dir = "./.checkpoints/lora/500_epochs_lp"
checkpoint_postfix = "_dreambooth_last"

if lora:
    # Wrap layer indices 0-8 of the text and vision trunks with rank-4 LoRA adapters.
    # Separate lists per modality so neither can be affected if the helper mutates one.
    model.modality_trunks.update(
        LoRA.apply_lora_modality_trunks(
            model.modality_trunks,
            rank=4,
            layer_idxs={
                ModalityType.TEXT: list(range(9)),
                ModalityType.VISION: list(range(9)),
            },
            modality_names=[ModalityType.TEXT, ModalityType.VISION],
        )
    )

    # Load LoRA params if found
    LoRA.load_lora_modality_trunks(
        model.modality_trunks,
        checkpoint_dir=lora_checkpoint_dir,
        postfix=checkpoint_postfix,
    )

    if load_head_post_proc_finetuned:
        # Load the postprocessors & heads saved alongside the LoRA weights.
        load_module(
            model.modality_postprocessors,
            module_name="postprocessors",
            checkpoint_dir=lora_checkpoint_dir,
            postfix=checkpoint_postfix,
        )
        load_module(
            model.modality_heads,
            module_name="heads",
            checkpoint_dir=lora_checkpoint_dir,
            postfix=checkpoint_postfix,
        )
elif linear_probing:
    # Linear probing: load only the heads, from the linear-probing checkpoint.
    load_module(
        model.modality_heads,
        module_name="heads",
        checkpoint_dir=linear_probing_checkpoint_dir,
        postfix=checkpoint_postfix,
    )

# Switch to evaluation mode for inference, then move the weights to the chosen device.
model.eval()
model.to(device)

# Load data: one batch per modality, each loader is given the target `device`.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device, to_tensor=True),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

# Forward pass without autograd tracking; the result is indexed by ModalityType below.
with torch.no_grad():
    embeddings = model(inputs)

# Logit multiplier for the comparisons against text: `lora_factor` only when LoRA is enabled.
text_logit_scale = lora_factor if lora else 1

# Each entry: (label, row modality, column modality, logit scale or None to leave unscaled).
# NOTE(review): Vision x Audio is never rescaled even though the vision trunk receives LoRA
# adapters when `lora` is set. Confirm this asymmetry is intended.
similarity_reports = (
    ("Vision x Text: ", ModalityType.VISION, ModalityType.TEXT, text_logit_scale),
    ("Audio x Text: ", ModalityType.AUDIO, ModalityType.TEXT, text_logit_scale),
    ("Vision x Audio: ", ModalityType.VISION, ModalityType.AUDIO, None),
)
for label, row_modality, col_modality, logit_scale in similarity_reports:
    # Dot-product similarity matrix; softmax over the last dim normalises each row.
    logits = embeddings[row_modality] @ embeddings[col_modality].T
    if logit_scale is not None:
        logits = logits * logit_scale
    print(label, torch.softmax(logits, dim=-1))
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support