Use text embeddings to represent text as numerical vectors. This process lets computers understand and process text data, which is essential for many natural language processing (NLP) tasks.
The following NLP tasks use embeddings:
- Semantic search: Find documents or passages that are relevant to a query when the query doesn't use the exact same words as the documents.
- Text classification: Categorize text data into different classes, such as spam and not spam, or positive sentiment and negative sentiment.
- Machine translation: Translate text from one language to another while preserving the meaning.
- Text summarization: Create shorter summaries of text.
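For example, semantic search works by comparing vectors: texts with similar meanings map to nearby vectors, so a query can be matched to documents by vector similarity. The following minimal sketch uses made-up three-dimensional vectors and NumPy purely for illustration; real embedding models, such as the one used later in this notebook, produce vectors with hundreds of dimensions.

import numpy as np

# Made-up 3-dimensional "embeddings", purely for illustration.
query = np.array([0.9, 0.1, 0.0])
doc_a = np.array([0.8, 0.2, 0.1])   # similar in meaning to the query
doc_b = np.array([0.0, 0.1, 0.9])   # unrelated to the query

def cosine_similarity(u, v):
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

print(cosine_similarity(query, doc_a))  # higher score: more relevant
print(cosine_similarity(query, doc_b))  # lower score: less relevant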
This notebook uses Apache Beam's MLTransform to generate embeddings from text data.
Hugging Face's SentenceTransformers framework uses Python to generate sentence, text, and image embeddings. To generate text embeddings with Hugging Face models and MLTransform, use the SentenceTransformerEmbeddings module to specify the model configuration.
Install dependencies
Install Apache Beam and the dependencies needed to work with Hugging Face embeddings. The dependencies include the sentence-transformers package, which is required to use the SentenceTransformerEmbeddings module.
pip install 'apache_beam>=2.53.0' --quiet
pip install sentence-transformers --quiet
import tempfile
import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
Process the data
MLTransform is a PTransform that you can use for data preparation, including generating text embeddings.
Use MLTransform in write mode
In write mode, MLTransform saves the transforms and their attributes to an artifact location. Then, when you run MLTransform in read mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy.
For more information about using MLTransform, see Preprocess data with MLTransform in the Apache Beam documentation.
Get the data
The following text inputs come from the Hugging Face blog Getting Started With Embeddings.
MLTransform operates on dictionaries of data. To generate embeddings for specific columns, provide the column names as input to the columns argument of the SentenceTransformerEmbeddings transform.
content = [
    {'x': 'How do I get a replacement Medicare card?'},
    {'x': 'What is the monthly premium for Medicare Part B?'},
    {'x': 'How do I terminate my Medicare Part B (medical insurance)?'},
    {'x': 'How do I sign up for Medicare?'},
    {'x': 'Can I sign up for Medicare Part B if I am working and have health insurance through an employer?'},
    {'x': 'How do I sign up for Medicare Part B if I already have Part A?'},
    {'x': 'What are Medicare late enrollment penalties?'},
    {'x': 'What is Medicare and who can get it?'},
    {'x': 'How can I get help with my Medicare Part A and Part B premiums?'},
    {'x': 'What are the different parts of Medicare?'},
    {'x': 'Will my Medicare premiums be higher because of my higher income?'},
    {'x': 'What is TRICARE ?'},
    {'x': "Should I sign up for Medicare Part B if I have Veterans' Benefits?"}
]
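The dictionaries above use a single column named x. If your elements contain more than one text field, you can list several column names in the columns argument. The following sketch is purely illustrative; the question and answer field names are hypothetical and aren't used elsewhere in this notebook.

# Hypothetical sketch: embed two text columns with one transform.
multi_column_content = [
    {'question': 'How do I sign up for Medicare?',
     'answer': 'You can apply for Medicare through Social Security.'},
]
multi_column_transform = SentenceTransformerEmbeddings(
    model_name='sentence-transformers/sentence-t5-large',
    columns=['question', 'answer'])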
text_embedding_model_name = 'sentence-transformers/sentence-t5-large'
# Helper function that truncates each generated embedding to its first
# ten values so that the printed output stays readable.
def truncate_embeddings(d):
    for key in d.keys():
        d[key] = d[key][:10]
    return d
Generate text embeddings
This example uses the model sentence-transformers/sentence-t5-large to generate text embeddings. The model uses only the encoder from a T5-large model. The weights are stored in FP16. For more information about the model, see Sentence-T5: Scalable Sentence Encoders from Pre-trained Text-to-Text Models.
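If you want to inspect the model outside of Beam, you can load it directly with the sentence-transformers library. The following sketch, which assumes the model can be downloaded from the Hugging Face Hub, prints the size of the vectors the model produces (768 for this model, as the pipeline output below confirms).

from sentence_transformers import SentenceTransformer

# Load the same model directly to inspect its output dimensionality.
st_model = SentenceTransformer(text_embedding_model_name)
print(st_model.get_sentence_embedding_dimension())  # 768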
artifact_location_t5 = tempfile.mkdtemp(prefix='huggingface_')

embedding_transform = SentenceTransformerEmbeddings(
    model_name=text_embedding_model_name, columns=['x'])

with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(write_artifact_location=artifact_location_t5).with_transform(embedding_transform))

    transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)
    transformed_pcoll | "PrintEmbeddingShape" >> beam.Map(lambda x: print(f"Embedding shape: {len(x['x'])}"))
{'x': [-0.0317193828523159, -0.005265399813652039, -0.012499183416366577, 0.00018130357784684747, -0.005592408124357462, 0.06207558885216713, -0.01656288281083107, 0.0167048592120409, -0.01239298190921545, 0.03041897714138031]} Embedding shape: 10 {'x': [-0.015295305289328098, 0.005405726842582226, -0.015631258487701416, 0.022797023877501488, -0.027843449264764786, 0.03968179598450661, -0.004387892782688141, 0.022909151390194893, 0.01015392318367958, 0.04723235219717026]} Embedding shape: 10 {'x': [-0.03450256213545799, -0.002632762538269162, -0.022460950538516045, -0.011689935810863972, -0.027329981327056885, 0.07293087989091873, -0.03069353476166725, 0.05429817736148834, -0.01308195199817419, 0.017668722197413445]} Embedding shape: 10 {'x': [-0.02869587577879429, -0.0002648509689606726, -0.007186499424278736, -0.0003750955802388489, 0.012458174489438534, 0.06721009314060211, -0.013404129073023796, 0.03204648941755295, -0.021021844819188118, 0.04968355968594551]} Embedding shape: 10 {'x': [-0.03241290897130966, 0.006845517549663782, 0.02001815102994442, -0.0057969288900494576, 0.008191823959350586, 0.08160955458879471, -0.009215254336595535, 0.023534387350082397, -0.02034241147339344, 0.0357462577521801]} Embedding shape: 10 {'x': [-0.04592451825737953, -0.0025395643897354603, -0.01178023498505354, 0.011568977497518063, -0.0029014083556830883, 0.06971456110477448, -0.021167151629924774, 0.015902182087302208, -0.015007994137704372, 0.026213033124804497]} Embedding shape: 10 {'x': [0.005221465136855841, -0.002127869985997677, -0.002369001042097807, -0.019337018951773643, 0.023243796080350876, 0.05599674955010414, -0.022721167653799057, 0.024813007563352585, -0.010685156099498272, 0.03624529018998146]} Embedding shape: 10 {'x': [-0.035339221358299255, 0.010706206783652306, -0.001701260800473392, -0.00862252525985241, 0.006445988081395626, 0.08198338001966476, -0.022678885608911514, 0.01434261817485094, -0.008092232048511505, 0.03345781937241554]} Embedding shape: 10 {'x': [-0.030748076736927032, 0.009340512566268444, -0.013637945055961609, 0.011183148249983788, -0.013879665173590183, 0.046350326389074326, -0.024090109393000603, 0.02885228954255581, -0.01699884608387947, 0.01672385260462761]} Embedding shape: 10 {'x': [-0.040792081505060196, -0.00872269831597805, -0.015838179737329483, -0.03141209855675697, -7.104632823029533e-05, 0.08301416039466858, -0.034691162407398224, 0.0026397297624498606, 0.009255227632820606, 0.05415954813361168]} Embedding shape: 10 {'x': [-0.02156883291900158, 0.003969342447817326, -0.030446071177721024, 0.008231461979448795, -0.01271845493465662, 0.03793857619166374, -0.013524272479116917, -0.0385628417134285, -0.0058258213102817535, 0.03505263477563858]} Embedding shape: 10 {'x': [-0.027544165030121803, -0.01773364469408989, -0.013286487199366093, -0.008328652940690517, -0.011047529056668282, 0.05237515643239021, -0.016948163509368896, 0.02806701697409153, -0.0018120920285582542, 0.027241172268986702]} Embedding shape: 10 {'x': [-0.03464886546134949, -0.003521248232573271, -0.010239562019705772, -0.018618224188685417, 0.004094886127859354, 0.062059685587882996, -0.013881963677704334, -0.0008639032603241503, -0.029874088242650032, 0.033531222492456436]} Embedding shape: 10
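The write-mode run also saved the fitted transform's artifacts to artifact_location_t5, which the read-mode section later in this notebook reuses. If you're curious what was written, you can list the directory; the exact file names are an implementation detail of MLTransform.

import os

# Optional: inspect the artifacts that MLTransform saved during write mode.
for name in os.listdir(artifact_location_t5):
    print(name)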
You can pass additional arguments that are supported by sentence-transformers models, such as convert_to_numpy=False. These arguments are passed as a dict to the SentenceTransformerEmbeddings transform by using the inference_args parameter.
When you pass convert_to_numpy=False, the output contains torch.Tensor objects.
artifact_location_t5_with_inference_args = tempfile.mkdtemp(prefix='huggingface_')

embedding_transform = SentenceTransformerEmbeddings(
    model_name=text_embedding_model_name, columns=['x'],
    inference_args={'convert_to_numpy': False}
)

with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(write_artifact_location=artifact_location_t5_with_inference_args).with_transform(embedding_transform))

    # The outputs are in the PyTorch tensor type.
    transformed_pcoll | 'LogOutput' >> beam.Map(lambda x: print(type(x['x'])))
    transformed_pcoll | "PrintEmbeddingShape" >> beam.Map(lambda x: print(f"Embedding shape: {len(x['x'])}"))
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
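convert_to_numpy is only one example; any keyword argument accepted by the model's encode method can be forwarded through inference_args in the same way. For instance, the following sketch requests unit-length (normalized) embeddings, assuming the normalize_embeddings argument of SentenceTransformer's encode method.

# Sketch: forward another encode() argument through inference_args.
# normalize_embeddings=True asks the model for unit-length vectors.
normalized_embedding_transform = SentenceTransformerEmbeddings(
    model_name=text_embedding_model_name, columns=['x'],
    inference_args={'normalize_embeddings': True})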
Use MLTransform in read mode
In read mode, MLTransform uses the artifacts generated during write mode. In this case, the SentenceTransformerEmbeddings transform and its attributes are loaded from the saved artifacts. You don't need to specify the transform again in read mode.
In this way, MLTransform provides consistent preprocessing steps for training and inference workloads.
test_content = [
    {
        'x': 'This is a test sentence'
    },
    {
        'x': 'The park is full of dogs'
    },
    {
        'x': "Should I sign up for Medicare Part B if I have Veterans' Benefits?"
    }
]
# Uses the T5 model to generate text embeddings
with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(test_content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(read_artifact_location=artifact_location_t5))

    transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)
{'x': [0.00036313451710157096, -0.03929319977760315, -0.03574873134493828, 0.05015222355723381, 0.04295048117637634, 0.04800170287489891, 0.006883862894028425, -0.02567591704428196, -0.048067063093185425, 0.036534328013658524]}
{'x': [-0.053793832659721375, 0.006730600260198116, -0.025130020454525948, 0.04363932088017464, 0.03323192894458771, 0.008803879842162132, -0.015412433072924614, 0.008926985785365105, -0.061175212264060974, 0.04573329910635948]}
{'x': [-0.03464885801076889, -0.003521254053339362, -0.010239563882350922, -0.018618224188685417, 0.004094892647117376, 0.062059689313173294, -0.013881963677704334, -0.000863900815602392, -0.029874078929424286, 0.03353121876716614]}