Bootstrap

大模型 Spring AI AbstractEmbeddingModel

在Spring AI框架中,AbstractEmbeddingModel是一个关键的抽象类,用于为不同的嵌入模型提供统一的接口和抽象层。以下是对AbstractEmbeddingModel的详细解析,包括其定义、作用、核心功能、使用场景以及示例代码。

一、定义与作用

AbstractEmbeddingModel是Spring AI框架中用于嵌入模型的抽象基类。它定义了一套统一的接口,使得开发者可以轻松地集成和使用不同的嵌入模型,而无需深入了解每个模型的底层实现细节。这个抽象类的主要作用是提供一个通用的框架,以便开发者能够根据需要选择和切换嵌入模型,同时保持代码的一致性和可维护性。

二、核心功能

AbstractEmbeddingModel通过继承自EmbeddingModel接口的方法,实现了以下核心功能:

  1. 文本向量化:能够将输入的文本数据转换为连续的数值向量表示。这是嵌入模型最基本也是最重要的功能之一。
  2. 批量处理:支持对一批文本进行向量化处理,提高了处理效率。
  3. 维度一致性:不同嵌入模型生成的向量可能具有不同的维度,但AbstractEmbeddingModel通过统一的接口和抽象层,确保了向量维度的一致性,便于后续处理和分析。
  4. 可移植性和简单性:AbstractEmbeddingModel的设计符合Spring框架的模块化和可互换性原则,允许开发者在不同的嵌入技术或模型之间进行切换时,只需进行最少的代码更改。同时,它也简化了将文本转换为嵌入的过程,消除了处理原始文本数据和嵌入算法的复杂性。

三、使用场景

AbstractEmbeddingModel在Spring AI框架中具有广泛的应用场景,包括但不限于:

  1. 自然语言处理:将单词或句子转换成向量表示,用于文本分类、机器翻译、情感分析等任务。
  2. 推荐系统:将用户和产品映射成向量表示,以便更好地理解用户的喜好和匹配物品。
  3. 检索增强生成(RAG):在RAG技术中,嵌入模型用于将文档转换为向量数据并存储在向量数据库中,以便后续通过自然语言查询来检索相关数据。

四、实现与配置

在Spring AI框架中实现和使用AbstractEmbeddingModel通常涉及以下几个步骤:

  1. 选择嵌入模型:根据具体的应用场景和需求,选择合适的嵌入模型(如Word2Vec、GloVe、BERT等)并配置其相关信息。
  2. 集成嵌入模型:通过继承AbstractEmbeddingModel并实现其抽象方法,将选定的嵌入模型集成到Spring AI框架中。
  3. 配置Spring AI:在Spring AI的配置文件中(如application.yml或application.properties),配置嵌入模型的相关信息,如模型名称、基础URL等。
  4. 使用嵌入模型:在应用程序中,通过注入EmbeddingModel或AbstractEmbeddingModel的实例,调用其方法将文本数据转换为向量表示,并进行后续的处理和分析。

五、示例代码

以下是一个使用AbstractEmbeddingModel的示例代码,展示了如何将文本数据转换为向量表示:

// 假设已经有一个具体的嵌入模型实现类MyEmbeddingModel,它继承自AbstractEmbeddingModel
@Resource
private MyEmbeddingModel embeddingModel;
 
public void embedText(String text) {
    // 将文本转换为向量
    List<Double> vector = embeddingModel.embed(text);
    
    // 输出向量结果
    System.out.println("Text embedding vector: " + vector);
}
 
// 批量处理文本
public void embedTexts(List<String> texts) {
    // 将一批文本转换为向量
    List<List<Double>> vectors = embeddingModel.embed(texts);
    
    // 输出向量结果
    for (int i = 0; i < vectors.size(); i++) {
        System.out.println("Text " + (i + 1) + " embedding vector: " + vectors.get(i));
    }
}

在上面的示例代码中,MyEmbeddingModel是AbstractEmbeddingModel的一个具体实现类。我们通过注入MyEmbeddingModel的实例,并调用其embed方法将文本数据转换为向量表示。同时,也展示了如何对一批文本进行批量处理。

示例:OllamaEmbeddingModel

public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
    private final OllamaApi ollamaApi;
    private final OllamaOptions defaultOptions;
    private final ObservationRegistry observationRegistry;
    private final OllamaModelManager modelManager;
    private EmbeddingModelObservationConvention observationConvention;

    public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
        Assert.notNull(ollamaApi, "ollamaApi must not be null");
        Assert.notNull(defaultOptions, "options must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
        this.ollamaApi = ollamaApi;
        this.defaultOptions = defaultOptions;
        this.observationRegistry = observationRegistry;
        this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
        this.initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
    }

    public static Builder builder() {
        return new Builder();
    }

    public float[] embed(Document document) {
        return this.embed(document.getText());
    }

    public EmbeddingResponse call(EmbeddingRequest request) {
        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
        OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = this.ollamaEmbeddingRequest(request.getInstructions(), request.getOptions());
        EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder().embeddingRequest(request).provider(OllamaApi.PROVIDER_NAME).requestOptions(this.buildRequestOptions(ollamaEmbeddingRequest)).build();
        return (EmbeddingResponse)EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry).observe(() -> {
            OllamaApi.EmbeddingsResponse response = this.ollamaApi.embed(ollamaEmbeddingRequest);
            AtomicInteger indexCounter = new AtomicInteger(0);
            List<Embedding> embeddings = response.embeddings().stream().map((e) -> new Embedding(e, indexCounter.getAndIncrement())).toList();
            EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(), OllamaEmbeddingUsage.from(response));
            EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata);
            observationContext.setResponse(embeddingResponse);
            return embeddingResponse;
        });
    }

    OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {
        OllamaOptions runtimeOptions = null;
        if (options != null && options instanceof OllamaOptions ollamaOptions) {
            runtimeOptions = ollamaOptions;
        }

        OllamaOptions mergedOptions = (OllamaOptions)ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
        if (!StringUtils.hasText(mergedOptions.getModel())) {
            throw new IllegalArgumentException("Model is not set!");
        } else {
            String model = mergedOptions.getModel();
            return new OllamaApi.EmbeddingsRequest(model, inputContent, OllamaEmbeddingModel.DurationParser.parse(mergedOptions.getKeepAlive()), OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
        }
    }

    private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request) {
        return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
    }

    private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
        if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
            this.modelManager.pullModel(model, pullModelStrategy);
        }

    }

    public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    public static class DurationParser {
        private static final Pattern PATTERN = Pattern.compile("(\\d+)(ms|s|m|h)");

        public DurationParser() {
        }

        public static Duration parse(String input) {
            if (!StringUtils.hasText(input)) {
                return null;
            } else {
                Matcher matcher = PATTERN.matcher(input);
                if (matcher.matches()) {
                    long value = Long.parseLong(matcher.group(1));
                    Duration var10000;
                    switch (matcher.group(2)) {
                        case "ms" -> var10000 = Duration.ofMillis(value);
                        case "s" -> var10000 = Duration.ofSeconds(value);
                        case "m" -> var10000 = Duration.ofMinutes(value);
                        case "h" -> var10000 = Duration.ofHours(value);
                        default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
                    }

                    return var10000;
                } else {
                    throw new IllegalArgumentException("Invalid duration format: " + input);
                }
            }
        }
    }

    public static final class Builder {
        private OllamaApi ollamaApi;
        private OllamaOptions defaultOptions;
        private ObservationRegistry observationRegistry;
        private ModelManagementOptions modelManagementOptions;

        private Builder() {
            this.defaultOptions = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
            this.observationRegistry = ObservationRegistry.NOOP;
            this.modelManagementOptions = ModelManagementOptions.defaults();
        }

        public Builder ollamaApi(OllamaApi ollamaApi) {
            this.ollamaApi = ollamaApi;
            return this;
        }

        public Builder defaultOptions(OllamaOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
            this.modelManagementOptions = modelManagementOptions;
            return this;
        }

        public OllamaEmbeddingModel build() {
            return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry, this.modelManagementOptions);
        }
    }
}

注意事项

    1. 选择合适的嵌入模型:根据具体的应用场景和需求,选择合适的嵌入模型并配置其相关信息。
    1. 性能考虑:嵌入模型的性能可能会影响整个系统的响应时间。因此,在选择嵌入模型时,需要考虑其处理速度和资源消耗。
    1. 维度一致性:确保不同嵌入模型生成的向量具有一致的维度,以便后续处理和分析。如果维度不一致,可能需要进行额外的转换或处理。

综上所述,AbstractEmbeddingModel在Spring AI框架中起到了至关重要的作用,它为集成和使用不同的嵌入模型提供了基础支持,并在自然语言处理、推荐系统以及检索增强生成等领域具有广泛的应用前景。

;