在Spring AI框架中,AbstractEmbeddingModel是一个关键的抽象类,用于为不同的嵌入模型提供统一的接口和抽象层。以下是对AbstractEmbeddingModel的详细解析,包括其定义、作用、核心功能、使用场景以及示例代码。
一、定义与作用
AbstractEmbeddingModel是Spring AI框架中用于嵌入模型的抽象基类。它定义了一套统一的接口,使得开发者可以轻松地集成和使用不同的嵌入模型,而无需深入了解每个模型的底层实现细节。这个抽象类的主要作用是提供一个通用的框架,以便开发者能够根据需要选择和切换嵌入模型,同时保持代码的一致性和可维护性。
二、核心功能
AbstractEmbeddingModel通过继承自EmbeddingModel接口的方法,实现了以下核心功能:
- 文本向量化:能够将输入的文本数据转换为连续的数值向量表示。这是嵌入模型最基本也是最重要的功能之一。
- 批量处理:支持对一批文本进行向量化处理,提高了处理效率。
- 维度一致性:不同嵌入模型生成的向量可能具有不同的维度,但AbstractEmbeddingModel通过统一的接口和抽象层,确保了向量维度的一致性,便于后续处理和分析。
- 可移植性和简单性:AbstractEmbeddingModel的设计符合Spring框架的模块化和可互换性原则,允许开发者在不同的嵌入技术或模型之间进行切换时,只需进行最少的代码更改。同时,它也简化了将文本转换为嵌入的过程,消除了处理原始文本数据和嵌入算法的复杂性。
三、使用场景
AbstractEmbeddingModel在Spring AI框架中具有广泛的应用场景,包括但不限于:
- 自然语言处理:将单词或句子转换成向量表示,用于文本分类、机器翻译、情感分析等任务。
- 推荐系统:将用户和产品映射成向量表示,以便更好地理解用户的喜好和匹配物品。
- 检索增强生成(RAG):在RAG技术中,嵌入模型用于将文档转换为向量数据并存储在向量数据库中,以便后续通过自然语言查询来检索相关数据。
四、实现与配置
在Spring AI框架中实现和使用AbstractEmbeddingModel通常涉及以下几个步骤:
- 选择嵌入模型:根据具体的应用场景和需求,选择合适的嵌入模型(如Word2Vec、GloVe、BERT等)并配置其相关信息。
- 集成嵌入模型:通过继承AbstractEmbeddingModel并实现其抽象方法,将选定的嵌入模型集成到Spring AI框架中。
- 配置Spring AI:在Spring AI的配置文件中(如application.yml或application.properties),配置嵌入模型的相关信息,如模型名称、基础URL等。
- 使用嵌入模型:在应用程序中,通过注入EmbeddingModel或AbstractEmbeddingModel的实例,调用其方法将文本数据转换为向量表示,并进行后续的处理和分析。
五、示例代码
以下是一个使用AbstractEmbeddingModel的示例代码,展示了如何将文本数据转换为向量表示:
// 假设已经有一个具体的嵌入模型实现类MyEmbeddingModel,它继承自AbstractEmbeddingModel
@Resource
private MyEmbeddingModel embeddingModel;
public void embedText(String text) {
// 将文本转换为向量
List<Double> vector = embeddingModel.embed(text);
// 输出向量结果
System.out.println("Text embedding vector: " + vector);
}
// 批量处理文本
public void embedTexts(List<String> texts) {
// 将一批文本转换为向量
List<List<Double>> vectors = embeddingModel.embed(texts);
// 输出向量结果
for (int i = 0; i < vectors.size(); i++) {
System.out.println("Text " + (i + 1) + " embedding vector: " + vectors.get(i));
}
}
在上面的示例代码中,MyEmbeddingModel是AbstractEmbeddingModel的一个具体实现类。我们通过注入MyEmbeddingModel的实例,并调用其embed方法将文本数据转换为向量表示。同时,也展示了如何对一批文本进行批量处理。
示例:OllamaEmbeddingModel
public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
private final OllamaApi ollamaApi;
private final OllamaOptions defaultOptions;
private final ObservationRegistry observationRegistry;
private final OllamaModelManager modelManager;
private EmbeddingModelObservationConvention observationConvention;
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
this.ollamaApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
this.initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}
public static Builder builder() {
return new Builder();
}
public float[] embed(Document document) {
return this.embed(document.getText());
}
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = this.ollamaEmbeddingRequest(request.getInstructions(), request.getOptions());
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder().embeddingRequest(request).provider(OllamaApi.PROVIDER_NAME).requestOptions(this.buildRequestOptions(ollamaEmbeddingRequest)).build();
return (EmbeddingResponse)EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry).observe(() -> {
OllamaApi.EmbeddingsResponse response = this.ollamaApi.embed(ollamaEmbeddingRequest);
AtomicInteger indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = response.embeddings().stream().map((e) -> new Embedding(e, indexCounter.getAndIncrement())).toList();
EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(), OllamaEmbeddingUsage.from(response));
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata);
observationContext.setResponse(embeddingResponse);
return embeddingResponse;
});
}
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {
OllamaOptions runtimeOptions = null;
if (options != null && options instanceof OllamaOptions ollamaOptions) {
runtimeOptions = ollamaOptions;
}
OllamaOptions mergedOptions = (OllamaOptions)ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
} else {
String model = mergedOptions.getModel();
return new OllamaApi.EmbeddingsRequest(model, inputContent, OllamaEmbeddingModel.DurationParser.parse(mergedOptions.getKeepAlive()), OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
}
}
private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request) {
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
}
private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
this.modelManager.pullModel(model, pullModelStrategy);
}
}
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}
public static class DurationParser {
private static final Pattern PATTERN = Pattern.compile("(\\d+)(ms|s|m|h)");
public DurationParser() {
}
public static Duration parse(String input) {
if (!StringUtils.hasText(input)) {
return null;
} else {
Matcher matcher = PATTERN.matcher(input);
if (matcher.matches()) {
long value = Long.parseLong(matcher.group(1));
Duration var10000;
switch (matcher.group(2)) {
case "ms" -> var10000 = Duration.ofMillis(value);
case "s" -> var10000 = Duration.ofSeconds(value);
case "m" -> var10000 = Duration.ofMinutes(value);
case "h" -> var10000 = Duration.ofHours(value);
default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
}
return var10000;
} else {
throw new IllegalArgumentException("Invalid duration format: " + input);
}
}
}
}
public static final class Builder {
private OllamaApi ollamaApi;
private OllamaOptions defaultOptions;
private ObservationRegistry observationRegistry;
private ModelManagementOptions modelManagementOptions;
private Builder() {
this.defaultOptions = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
this.observationRegistry = ObservationRegistry.NOOP;
this.modelManagementOptions = ModelManagementOptions.defaults();
}
public Builder ollamaApi(OllamaApi ollamaApi) {
this.ollamaApi = ollamaApi;
return this;
}
public Builder defaultOptions(OllamaOptions defaultOptions) {
this.defaultOptions = defaultOptions;
return this;
}
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
}
public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
this.modelManagementOptions = modelManagementOptions;
return this;
}
public OllamaEmbeddingModel build() {
return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry, this.modelManagementOptions);
}
}
}
注意事项
- 选择合适的嵌入模型:根据具体的应用场景和需求,选择合适的嵌入模型并配置其相关信息。
- 性能考虑:嵌入模型的性能可能会影响整个系统的响应时间。因此,在选择嵌入模型时,需要考虑其处理速度和资源消耗。
- 维度一致性:确保不同嵌入模型生成的向量具有一致的维度,以便后续处理和分析。如果维度不一致,可能需要进行额外的转换或处理。
综上所述,AbstractEmbeddingModel在Spring AI框架中起到了至关重要的作用,它为集成和使用不同的嵌入模型提供了基础支持,并在自然语言处理、推荐系统以及检索增强生成等领域具有广泛的应用前景。