aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAnthony Wang2024-07-26 23:03:29 -0500
committerAnthony Wang2024-07-26 23:03:29 -0500
commit088075b13330be9730e339d9eaf01e8ab468d6c3 (patch)
treeb90bbca4b68c8bbe8ca9c2dc69a6bab2df5816e9 /model.py
Initial commit
Diffstat (limited to 'model.py')
-rw-r--r--model.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d2e765d
--- /dev/null
+++ b/model.py
@@ -0,0 +1,38 @@
+import torch
+from imagebind import data
+from imagebind.models import imagebind_model
+from imagebind.models.imagebind_model import ModalityType
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+print("Loading model")
+model = imagebind_model.imagebind_huge(pretrained=True)
+model.eval()
+model.to(device)
+
+
+def embed(mode, data):
+ with torch.no_grad():
+ return model({mode: data})[mode][0]
+
+
+def embed_text(text):
+ return embed(ModalityType.TEXT, data.load_and_transform_text([text], device))
+
+
+def embed_audio(audio_path):
+ return embed(
+ ModalityType.AUDIO, data.load_and_transform_audio_data([audio_path], device)
+ )
+
+
+def embed_image(image_path):
+ return embed(
+ ModalityType.VISION, data.load_and_transform_vision_data([image_path], device)
+ )
+
+
+def embed_video(video_path):
+ return embed(
+ ModalityType.VISION, data.load_and_transform_video_data([video_path], device)
+ )