Custom embeddings

Shows how to upload custom embeddings to improve similarity search.

How to upload custom embeddings

Custom embeddings make data exploration more effective by improving similarity search.

You can upload up to ten custom embedding types per workspace, for any data type.

This lets you experiment with different embeddings to find the ones that work best for data selection.
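Each workspace is limited to ten custom embedding types, and each type needs a unique name. It can therefore help to check both constraints before you call client.create_embedding(). The sketch below is a minimal, hypothetical helper. `check_new_embedding_type` is not part of the Labelbox SDK, and it assumes you have already collected the names of the existing custom embedding types, for example from the embeddings returned by client.get_embeddings().

```python
MAX_CUSTOM_EMBEDDING_TYPES = 10  # per-workspace limit stated in this guide


def check_new_embedding_type(existing_names, new_name, limit=MAX_CUSTOM_EMBEDDING_TYPES):
    """Raise ValueError if new_name is already taken or the workspace limit is reached.

    Hypothetical helper: existing_names is assumed to hold the names of the
    custom embedding types that already exist in the workspace.
    """
    names = set(existing_names)
    if new_name in names:
        raise ValueError(f"Embedding name {new_name!r} is already in use")
    if len(names) >= limit:
        raise ValueError(
            f"Workspace already has {len(names)} custom embedding types (limit {limit})"
        )


# Example with made-up names: passes silently because the name is new and under the limit
check_new_embedding_type(["clip_512", "resnet_2048"], "my_custom_embedding_2048_dimensions")
```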

Before you start

This example requires the following libraries:

# Custom embeddings require Labelbox SDK version 3.69 or later.
import labelbox as lb
import numpy as np
import json
import uuid
import random

Replace with your API key

API_KEY = ""  # paste your Labelbox API key here
client = lb.Client(API_KEY)
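If you would rather not paste the key into a notebook or script, one option is to read it from an environment variable. This is a minimal sketch: the helper `load_api_key` and the variable name LABELBOX_API_KEY are assumptions for illustration, so use whatever name your environment defines.

```python
import os


def load_api_key(env_var: str = "LABELBOX_API_KEY") -> str:
    """Return the API key stored in an environment variable.

    Hypothetical helper; the default variable name is an assumption.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable first")
    return key
```

You could then create the client with client = lb.Client(api_key=load_api_key()).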

Select data rows

First, we need to fetch data rows from a Labelbox dataset.

To improve similarity search, you need to upload custom embeddings to at least 1,000 data rows.

dataset = client.get_dataset("<DATASET-ID>")

export_task = dataset.export()
export_task.wait_till_done()

data_rows = []

# Print any errors returned by the export
if export_task.has_errors():
    export_task.get_buffered_stream(stream_type=lb.StreamType.ERRORS).start(
        stream_handler=lambda error: print(error))

if export_task.has_result():
    # Start export stream
    stream = export_task.get_buffered_stream()

    # Collect each exported data row (a dict) for the next step
    for data_row in stream:
        data_rows.append(data_row.json)

Extract the data row IDs:

data_row_dict = [{"data_row_id": dr["data_row"]["id"]} for dr in data_rows]
data_row_dict = data_row_dict[:1000] # keep the first 1000 examples for the sake of this demo
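To check the extraction step without a live export, the following offline sketch applies the same logic to records shaped like the export output above, where each record has a data_row dictionary with an id. The helper `extract_data_row_ids` and the sample records are hypothetical. The helper also warns when fewer than the recommended 1,000 data rows are available.

```python
import warnings

RECOMMENDED_MIN_ROWS = 1000  # minimum suggested in this guide for useful similarity search


def extract_data_row_ids(records, limit=RECOMMENDED_MIN_ROWS):
    """Return up to `limit` {"data_row_id": ...} dicts from export-style records."""
    rows = [{"data_row_id": r["data_row"]["id"]} for r in records]
    if len(rows) < RECOMMENDED_MIN_ROWS:
        warnings.warn(
            f"Only {len(rows)} data rows found; at least {RECOMMENDED_MIN_ROWS} are recommended"
        )
    return rows[:limit]


# Made-up records that mimic the shape of the export output
sample = [{"data_row": {"id": f"row-{i}"}} for i in range(3)]
print(extract_data_row_ids(sample))
```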

Create custom embedding payload

To prepare the data:

  1. Generate random vectors for embeddings (max: 2048 dimensions)

    nb_data_rows = len(data_row_dict)
    print("Number of data rows: ", nb_data_rows)
    # Labelbox supports custom embedding vectors of up to 2048 dimensions
    custom_embeddings = [list(np.random.random(2048)) for _ in range(nb_data_rows)]
    
    
  2. List custom embeddings in your Labelbox workspace:

    embeddings = client.get_embeddings()
    
  3. Choose an existing embedding type or create a new one.

    To create a new type, pass a unique name and the number of dimensions to client.create_embedding():

    # Name of the custom embedding must be unique
    embedding = client.create_embedding("my_custom_embedding_2048_dimensions", 2048)
    
  4. Create payload

  • The payload must contain the key (data row ID or global key) and the new embedding vector. dataset.upsert_data_rows() updates only the values you pass in the payload; all other existing data row values are left unchanged.

    payload = []
    # Use a separate loop variable so the data_row_dict list is not overwritten
    for row, custom_embedding in zip(data_row_dict, custom_embeddings):
        payload.append({"key": lb.UniqueId(row["data_row_id"]),
                        "embeddings": [{"embedding_id": embedding.id, "vector": custom_embedding}]})

    print("payload", len(payload), payload[:1])
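Before upserting, it can be worth confirming that every vector matches the embedding type's dimensions and stays within the 2,048-dimension maximum. The sketch below builds the same payload shape in plain Python so it runs without the SDK. `build_payload` is a hypothetical helper, and the keys are plain strings here; in real code, wrap each key in lb.UniqueId(...) as the payload step does.

```python
MAX_DIMS = 2048  # maximum vector size stated in this guide


def build_payload(data_row_ids, vectors, embedding_id, dims):
    """Pair data row IDs with vectors, checking each vector's length.

    Hypothetical helper: keys stay plain strings (the SDK example wraps them
    in lb.UniqueId), and vector values are converted to plain Python floats.
    """
    if dims > MAX_DIMS:
        raise ValueError(f"dims={dims} exceeds the {MAX_DIMS}-dimension maximum")
    if len(data_row_ids) != len(vectors):
        raise ValueError("Need exactly one vector per data row")
    payload = []
    for row_id, vector in zip(data_row_ids, vectors):
        if len(vector) != dims:
            raise ValueError(f"Vector for {row_id} has {len(vector)} values, expected {dims}")
        payload.append({
            "key": row_id,
            "embeddings": [{"embedding_id": embedding_id, "vector": [float(v) for v in vector]}],
        })
    return payload


# Usage with made-up IDs and 4-dimensional vectors
example = build_payload(["row-1", "row-2"], [[0.1] * 4, [0.2] * 4], "emb-123", dims=4)
print(example[0])
```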
    

Upload payload

  1. Upsert data rows with custom embeddings

    task = dataset.upsert_data_rows(payload)
    task.wait_till_done()
    print(task.errors)
    print(task.status)
    
  2. Get the count of imported vectors for a custom embedding type.

    The count can take a few minutes to update, depending on the number of data rows associated with the embedding type.

    count = embedding.get_imported_vector_count()
    
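  3. Delete the custom embedding type when you no longer need it. Skip this step if you plan to follow the next section, which reuses this embedding type by name.

    embedding.delete()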
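Because the imported vector count in step 2 can lag behind the upload, you may want to poll it rather than read it once. This is a minimal sketch that assumes you pass in any zero-argument callable returning the current count. In practice, that could be embedding.get_imported_vector_count. The helper name `wait_for_vector_count` and the timeout and interval values are illustrative assumptions, not SDK defaults.

```python
import time


def wait_for_vector_count(get_count, expected, timeout_s=600, interval_s=15):
    """Poll get_count() until it reaches `expected` or `timeout_s` elapses.

    Returns the last observed count, which may be below `expected` on timeout.
    """
    deadline = time.monotonic() + timeout_s
    count = get_count()
    while count < expected and time.monotonic() < deadline:
        time.sleep(interval_s)
        count = get_count()
    return count


# Offline usage with a fake counter that grows on each call
counts = iter([0, 400, 800, 1000])
print(wait_for_vector_count(lambda: next(counts), expected=1000, interval_s=0))
```

With a real embedding, call something like wait_for_vector_count(embedding.get_imported_vector_count, expected=len(payload)) before deleting the embedding type.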
    

Upload custom embeddings during data row creation

  1. Create a dataset

    # Create a dataset
    dataset_new = client.create_dataset(name="data_rows_with_embeddings")
    
  2. Fetch an embedding type and create dummy vector data.

    embedding = client.get_embedding_by_name("my_custom_embedding_2048_dimensions")
    vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
    
  3. Upload data rows with embeddings.

    uploads = []
    # Generate data rows
    for i in range(1, 9):
        uploads.append({
            "row_data": f"https://storage.googleapis.com/labelbox-datasets/People_Clothing_Segmentation/jpeg_images/IMAGES/img_000{i}.jpeg",
            "global_key": f"TEST-ID-{uuid.uuid1()}",  # unique global key per data row
            "embeddings": [{
                "embedding_id": embedding.id,
                "vector": vector
            }]
        })
    
    task1 = dataset_new.create_data_rows(uploads)
    task1.wait_till_done()
    print("ERRORS: " , task1.errors)
    print("RESULTS:" , task1.result)
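You can also assemble the upload list with a small function that checks the vector length against the embedding's dimensions and generates unique global keys. This is a hypothetical sketch in plain Python: `build_uploads` is not an SDK function, the URLs in the usage are placeholders, and the TEST-ID- prefix simply mirrors the example above.

```python
import uuid


def build_uploads(urls, embedding_id, vector, dims, prefix="TEST-ID-"):
    """Return create_data_rows-style dicts, one per URL, all sharing one vector."""
    if len(vector) != dims:
        raise ValueError(f"Vector has {len(vector)} values, expected {dims}")
    uploads = []
    for url in urls:
        uploads.append({
            "row_data": url,
            "global_key": f"{prefix}{uuid.uuid4()}",  # random UUID keeps keys unique
            "embeddings": [{"embedding_id": embedding_id, "vector": list(vector)}],
        })
    return uploads


# Usage with placeholder URLs and a 3-dimensional vector
rows = build_uploads(
    [f"https://example.com/img_{i}.jpeg" for i in range(1, 4)],
    embedding_id="emb-123",
    vector=[1.5, 1.2, 1.9],
    dims=3,
)
print(len({r["global_key"] for r in rows}))
```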