"""Single-item embedding helpers built on a pretrained ImageBind model.

Importing this module builds ``imagebind_model.imagebind_huge(pretrained=True)``,
puts it in eval mode, and moves it to ``"cuda:0"`` when CUDA is available,
otherwise to ``"cpu"``.
"""

import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

print("Loading model")
# nn.Module.eval() and nn.Module.to() both return the module itself, so the
# setup can be chained while still binding the same object to ``model``.
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)


def embed(mode, data):
    """Run ``model`` on a single-modality input and return its first entry.

    Args:
        mode: The ``ModalityType`` key selecting which modality ``data`` is.
        data: The already-preprocessed input for that modality. Inside this
            function the name shadows the module-level ``imagebind.data``
            import.

    Returns:
        ``model({mode: data})[mode][0]``, with the forward pass run under
        ``torch.no_grad()``.
    """
    with torch.no_grad():
        outputs = model({mode: data})
    return outputs[mode][0]


def _embed_single(mode, loader, item):
    """Preprocess ``item`` as a one-element batch with ``loader``, then embed it."""
    batch = loader([item], device)
    return embed(mode, batch)


def embed_text(text):
    """Return the embedding of one text input."""
    return _embed_single(ModalityType.TEXT, data.load_and_transform_text, text)


def embed_audio(audio_path):
    """Return the embedding of the audio file at ``audio_path``."""
    return _embed_single(
        ModalityType.AUDIO, data.load_and_transform_audio_data, audio_path
    )


def embed_image(image_path):
    """Return the embedding of the image file at ``image_path``."""
    return _embed_single(
        ModalityType.VISION, data.load_and_transform_vision_data, image_path
    )


def embed_video(video_path):
    """Return the embedding of the video file at ``video_path``.

    Videos are fed through the ``VISION`` modality, just like images; only the
    preprocessing function differs.
    """
    return _embed_single(
        ModalityType.VISION, data.load_and_transform_video_data, video_path
    )