
Extract hidden layers from a language model and aggregate them to obtain token (roughly word) embeddings and text embeddings (all reshaped into embedding matrices). This function is a wrapper of text::textEmbed().

Usage

text_to_vec(
  text,
  model,
  layers = "all",
  layer.to.token = "concatenate",
  token.to.word = TRUE,
  token.to.text = TRUE,
  encoding = "UTF-8",
  ...
)

Arguments

text

Can be:

  • a character string or vector of text (usually sentences)

  • a data frame with at least one character variable (for text from all character variables in a given data frame)

  • a file path on disk containing text

model

Model name on HuggingFace. See text_model_download. If the model has not been downloaded, it will be downloaded automatically.

layers

Layers to be extracted from the model, which are then aggregated by text::textEmbedLayerAggregation(). Defaults to "all", which extracts all layers. You may extract only the layers you need (e.g., 11:12). Note that layer 0 is the decontextualized input (embedding) layer, i.e., it does not comprise hidden states.

layer.to.token

Method used to aggregate hidden layers into one embedding per token. Defaults to "concatenate", which links each token's embeddings from the selected layers into one long row. Options include "concatenate", "mean", "min", and "max".
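The difference between the aggregation methods can be illustrated with a minimal sketch in base R (the matrices below are made-up stand-ins for two extracted layers, not actual model output):

```r
# Two hypothetical layers' embeddings for the same 3 tokens (4 dims each)
layer1 <- matrix(1:12, nrow = 3)   # 3 tokens x 4 dims
layer2 <- matrix(13:24, nrow = 3)  # 3 tokens x 4 dims

# "concatenate": link each token's layer embeddings into one long row (dim = 8)
concat <- cbind(layer1, layer2)

# "mean": element-wise average across layers (dim stays 4)
avg <- (layer1 + layer2) / 2

dim(concat)  # 3 x 8
dim(avg)     # 3 x 4
```

So with layers = c(0, 12) and the default "concatenate", each token's embedding has twice the model's hidden-state dimension.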

token.to.word

Aggregate subword token embeddings (when a whole word is out of vocabulary) into whole-word embeddings. Defaults to TRUE, which sums the subword token embeddings.
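For intuition, here is a minimal sketch of that summation (the word, its subword split, and the 3-dimensional vectors are hypothetical, not real tokenizer or model output):

```r
# "tokenization" might be split into subword tokens "token" and "##ization";
# with token.to.word = TRUE, their embeddings are summed element-wise
# into one whole-word embedding.
sub1 <- c(0.1, 0.2, 0.3)  # hypothetical embedding of "token"
sub2 <- c(0.4, 0.5, 0.6)  # hypothetical embedding of "##ization"
word <- sub1 + sub2       # whole-word embedding of "tokenization"
```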

token.to.text

Aggregate token embeddings into one embedding per text. Defaults to TRUE, which averages all token embeddings. If FALSE, the text embedding is the token embedding of [CLS] (the special token representing the beginning of a text sequence).
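The default averaging amounts to a column-wise mean over the token embedding matrix, as in this sketch (the 3-token, 2-dimensional matrix is made up for illustration):

```r
# Hypothetical token embedding matrix: 3 tokens x 2 dims
tokens <- rbind(c(1, 2),
                c(3, 4),
                c(5, 6))

# token.to.text = TRUE: average across tokens -> one text embedding
text_embed <- colMeans(tokens)  # c(3, 4)
```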

encoding

Text encoding (only used if text is a file). Defaults to "UTF-8".

...

Other parameters passed to text::textEmbed().

Value

A list of:

token.embed

Token (roughly word) embeddings

text.embed

Text embeddings, aggregated from token embeddings

Examples

if (FALSE) {
# text_init()  # initialize the environment

text = c("Download models from HuggingFace",
         "Chinese are East Asian",
         "Beijing is the capital of China")
embed = text_to_vec(text, model="bert-base-cased", layers=c(0, 12))
embed

embed1 = embed$token.embed[[1]]
embed2 = embed$token.embed[[2]]
embed3 = embed$token.embed[[3]]

View(embed1)
View(embed2)
View(embed3)
View(embed$text.embed)

plot_similarity(embed1, value.color="grey")
plot_similarity(embed2, value.color="grey")
plot_similarity(embed3, value.color="grey")
plot_similarity(rbind(embed1, embed2, embed3))
}