diff options
author | Anthony Wang | 2024-07-26 23:03:29 -0500 |
---|---|---|
committer | Anthony Wang | 2024-07-26 23:03:29 -0500 |
commit | 088075b13330be9730e339d9eaf01e8ab468d6c3 (patch) | |
tree | b90bbca4b68c8bbe8ca9c2dc69a6bab2df5816e9 /model.py |
Initial commit
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 38 |
1 files changed, 38 insertions, 0 deletions
"""Embedding helpers built on a pretrained ImageBind model.

Importing this module loads the ``imagebind_huge`` model with pretrained
weights, switches it to eval mode and moves it to the selected device.
The ``embed_*`` functions each take a single input (a string or a file
path), run it through the matching ImageBind loader, and return what
``embed`` produces for it.
"""

import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

print("Loading model")
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


@torch.no_grad()
def embed(mode, data):
    """Run the model on one modality and return the first output entry.

    Gradient tracking is disabled for the call.

    Args:
        mode: Modality key, used both to build the model input dict and to
            pick the result out of the model's output.
        data: Already-loaded input for that modality. Note that inside this
            function the name shadows the imported ``data`` module.

    Returns:
        Element ``0`` of the model output for ``mode``. The wrappers in this
        module always pass a one-item list to the loaders, so this is
        presumably the embedding of that single item -- confirm against the
        ImageBind output layout.
    """
    outputs = model({mode: data})
    per_mode = outputs[mode]
    return per_mode[0]


def embed_text(text):
    """Return the embedding of a single text string."""
    loaded = data.load_and_transform_text([text], device)
    return embed(ModalityType.TEXT, loaded)


def embed_audio(audio_path):
    """Return the embedding of the audio file at ``audio_path``."""
    loaded = data.load_and_transform_audio_data([audio_path], device)
    return embed(ModalityType.AUDIO, loaded)


def embed_image(image_path):
    """Return the embedding of the image file at ``image_path``."""
    loaded = data.load_and_transform_vision_data([image_path], device)
    return embed(ModalityType.VISION, loaded)


def embed_video(video_path):
    """Return the embedding of the video file at ``video_path``.

    Video is loaded with the video loader but embedded under the same
    ``VISION`` modality key as still images.
    """
    loaded = data.load_and_transform_video_data([video_path], device)
    return embed(ModalityType.VISION, loaded)