Bootstrap

Springboot集成阿里云通义千问(灵积模型)

我这里集成后,做成了一个工具jar包,如果有不同方式的,欢迎大家讨论,共同进步。

集成限制:

1、灵积模型有QPM(QPS)限制,每个模型不一样,需要根据每个模型适配

集成开发思路:

因有QPS限制,无法支持多任务并发执行,所以使用任务池操作,定时监听任务池中任务状态;

因系统中执行不能等待QPS释放后执行,故使用异步调用;

开发思路:

1、创建任务,提交到任务池中

2、任务监听器每10秒检查任务池中的任务执行情况:

        1)任务未执行:获任务token,获取到执行任务,否则不执行

        2)任务执行中:判断任务执行是否超时,如果超时,重置任务状态,重试计数加1

        3)任务执行失败:执行失败回调。从任务池中清除

        4)任务执行成功:从任务池中清除

3、任务执行:

        1)获取任务token,如果获取到就执行,否则不执行

        2)利用工具类请求灵积模型

        3)判断任务执行状态:成功:执行成功回调;失败:重试计数加1,重置任务状态

        4)归还token

集成编码

1、前置操作

详见阿里云灵积模型服务开发参开icon-default.png?t=O83Ahttps://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key?spm=a2c4g.11186623.0.0.1403193eLiHQfl

开发参考中获取到的API-KEY需要写到项目的配置文件中

2、创建灵积服务jar(aliyun-dashscope)

按照灵积模型Java jdk最佳实践的方式实现集成模型灵积模型Java jdk最佳实践icon-default.png?t=O83Ahttps://help.aliyun.com/zh/dashscope/java-sdk-best-practices?spm=a2c4g.11186623.0.0.4da417d9T9NKfMpom文件中引入jar

<dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>dashscope-sdk-java</artifactId>
            <version>2.15.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-pool2</artifactId>
        </dependency>
        <dependency>
            <groupId>com.aa.bb</groupId>
            <artifactId>common-redis</artifactId>
            <version>1.0.0</version>
        </dependency>

dashscope-sdk-java : 灵积服务模型jar

commons-pool2 : 对象池工具jar

common-redis :个人项目中redis工具包(可以自己封装一个)

3、编码

1)创建config

@Data
@Configuration
@ConfigurationProperties(prefix = "aliyun.dashscope")
public class DashScopeConfig {

    /**
     * api密钥
     */
    @Value("${aliyun.dashscope.apiKey}")
    private String apiKey;
    /**
     * 最大tokens数
     */
    private int maxTokens = 800;
    /**
     * 模型
     */
    private String model = "qwen-plus";
    /**
     * QPS
     */
    private int qps = 15;
    /**
     * qps缓存密钥
     */
    private String qpsRedisKey = "aliyun:dashscope:token";
    /**
     * 尝试计数
     */
    private int tryCount = 3;
    /**
     * task间隔时间
     */
    private int time = 10000;


}

2)创建对象池工厂

public class DashScopePoolFactory extends BasePooledObjectFactory<Generation> {
    @Override
    public Generation create() throws Exception {
        return new Generation();
    }

    @Override
    public PooledObject<Generation> wrap(Generation generation) {
        return new DefaultPooledObject<>(generation);
    }
}

 3)创建task

DashTask:任务类

@Data
@Slf4j
public class DashTask {

    /**
     * qps令牌
     */
    private Long qpsToken;

    /**
     * 正在执行
     */
    private boolean execute = false;

    /**
     * 成功
     */
    private boolean success = false;

    /**
     * 尝试计数
     */
    private int tryCount = 0;

    /**
     * 生成参数
     */
    private GenerationParam generationParam;

    /**
     * 结果
     */
    private Message result;

    /**
     * 成功回调
     */
    private Consumer<DashTask> successCallback;

    /**
     * 失败回调
     */
    private Consumer<DashTask> failCallback;

    public void setSuccess(boolean success) {
        if (success) {
            this.onSuccess();
        } else {
            this.onFail();
        }
    }

    /**
     * 论成功
     */
    public void onSuccess() {
        this.success = true;
        try {
            if (this.successCallback != null) {
                this.successCallback.accept(this);
            }
        } catch (Exception ex) {
            log.error("dash task onSuccess error:" + ex.getMessage());
        }
    }

    /**
     * 失败
     */
    public void onFail() {
        this.success = false;
        try {
            if (this.failCallback != null) {
                this.failCallback.accept(this);
            }
        } catch (Exception ex) {
            log.error("spark task onFail error:" + ex.getMessage());
        }
    }

}

DashListener:任务监听类 

@Slf4j
public class DashListener extends Listener {

    public DashListener(long interval) {
        super(interval, "dash-listener");
    }

    @Override
    public void run() {
        log.info("灵积服务(通义千问)任务监听 start");
        setExecute(true);
        while (isExecute()) {
            try {
                DashScopeUtils.asyncTaskStart();
                Thread.sleep(getInterval());
            } catch (Exception e) {
                log.error("灵积服务(通义千问)任务监听 error", e);
            }
        }
    }
}

4)创建工具类

DashScopeUtils:灵积模型基础工具类

@Slf4j
public class DashScopeUtils {

    private static volatile DashScopeConfig config;

    private static volatile RedisService redisService;

    /**
     * 获取令牌
     */
    public static final int GET_TOKEN_STATUS = 0;
    /**
     * 归还令牌
     */
    public static final int BACK_TOKEN_STATUS = 1;

    private static CopyOnWriteArraySet<DashTask> taskList = new CopyOnWriteArraySet<DashTask>();

    /**
     * 通用池
     */
    private static volatile GenericObjectPool<Generation> pool;

    /**
     * 创建消息
     *
     * @param role    角色
     * @param content 所容纳之物
     * @return {@link Message }
     */
    public static Message createMessage(Role role, String content) {
        return Message.builder().role(role.getValue()).content(content).build();
    }


    /**
     * 调用服务
     *
     * @param param param
     * @return {@link GenerationResult }
     */
    public static GenerationResult call(GenerationParam param) {
        try {
            if (param.getMaxTokens() == null) {
                param.setMaxTokens(getConfig().getMaxTokens());
            }
            Generation gen = getPool().borrowObject();
            GenerationResult call = gen.call(param);
            getPool().returnObject(gen);
            return call;
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            throw new RuntimeException(e.getMessage());
        }
    }

    /**
     * 获取对象池
     *
     * @return {@link GenericObjectPool }<{@link Generation }>
     */
    public static GenericObjectPool<Generation> getPool() {
        if (pool == null) {
            synchronized (DashScopeUtils.class) {
                if (pool == null) {
                    DashScopePoolFactory poolFactory = new DashScopePoolFactory();
                    GenericObjectPoolConfig<Generation> config = new GenericObjectPoolConfig<>();
                    config.setMaxTotal(64);
                    config.setMaxIdle(64);
                    config.setMinIdle(64);
                    Constants.apiKey = getConfig().getApiKey();
                    pool = new GenericObjectPool<>(poolFactory, config);
                }
            }
        }
        return pool;
    }

    /**
     * 获取配置
     *
     * @return {@link DashScopeConfig }
     */
    public static DashScopeConfig getConfig() {
        if (config == null) {
            synchronized (DashScopeConfig.class) {
                if (config == null) {
                    config = SpringUtils.getBean(DashScopeConfig.class);
                }
            }
        }
        return config;
    }

    /**
     * 异步任务启动
     */
    public static void asyncTaskStart() {
        instanceRedis();
        getConfig();
        // 令牌数量
        int current = 0;
        if (redisService.hasKey(config.getQpsRedisKey())) {
            current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());
        }
        if (current > 0) {
            String all = config.getQpsRedisKey() + ":*";
            int size = redisService.keys(all).size();
            if (size < current) {
                redisService.decr(config.getQpsRedisKey(), current - size);
            }
        }
        if (!taskList.isEmpty()) {
            Iterator<DashTask> iterator = taskList.iterator();
            while (iterator.hasNext()) {
                DashTask dashTask = iterator.next();
                if (dashTask.isExecute()) {
                    if (!redisService.hasKey(config.getQpsRedisKey() + ":" + dashTask.getQpsToken())) {
                        dashTask.setExecute(false);
                        dashTask.setTryCount(dashTask.getTryCount()+1);
                    }
                    continue;
                } else if (dashTask.isSuccess()) {
                    taskList.remove(dashTask);
                } else if (dashTask.getTryCount() > config.getTryCount()) {
                    dashTask.setSuccess(false);
                    taskList.remove(dashTask);
                } else if (!asyncTaskStart(dashTask)) {
                    break;
                }
            }
        }
    }

    /**
     * 提交任务
     *
     * @param dashTask 短跑任务
     */
    public static void submitTask(DashTask dashTask) {
        taskList.add(dashTask);
    }

    /**
     * 异步任务启动
     *
     * @param task 任务
     * @return boolean
     */
    private static boolean asyncTaskStart(DashTask task) {
        if (qpsToken(GET_TOKEN_STATUS, task)) {
            AsyncManager.me().execute(() -> {
                try {
                    task.setExecute(true);
                    GenerationResult call = call(task.getGenerationParam());
                    task.setResult(call.getOutput().getChoices().get(0).getMessage());
                    task.setSuccess(true);
                } catch (Exception e) {
                    task.setTryCount(task.getTryCount() + 1);
                }
                task.setExecute(false);
                qpsToken(BACK_TOKEN_STATUS, task);
            });
            return true;
        }
        return false;
    }


    /**
     * qps令牌
     *
     * @param status 地位
     * @param task   任务
     * @return boolean
     */
    private static synchronized boolean qpsToken(int status, DashTask task) {
        instanceRedis();
        getConfig();
        int current = 0;
        if (redisService.hasKey(config.getQpsRedisKey())) {
            current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());
        }

        // 获取token
        if (status == GET_TOKEN_STATUS) {
            if (current < config.getQps()) {
                Long incr = redisService.incr(config.getQpsRedisKey());
                task.setQpsToken(incr);
                redisService.set(config.getQpsRedisKey() + ":" + incr, "1", 1, TimeUnit.MINUTES);
                return true;
            } else {
                return false;
            }
        } else {
            if (current > 0) {
                redisService.decr(config.getQpsRedisKey());
            }
            redisService.del(config.getQpsRedisKey() + ":" + task.getQpsToken());
            return true;
        }
    }

    /**
     * 实例redis
     *
     * @return {@link RedisService}
     */
    private static RedisService instanceRedis() {
        if (redisService == null) {
            synchronized (DashScopeUtils.class) {
                if (redisService == null) {
                    redisService = SpringUtils.getBean(RedisService.class);
                }
                if (redisService == null) {
                    throw new RuntimeException("redisService is null");
                }
            }
        }
        return redisService;
    }
}

 QiamwenUtils:通义千问工具类


public class QianWenUtils {

    /**
     * 单轮对话
     *
     * @param content 内容
     * @param success 成功
     */
    public static void call(String content, Consumer<Message> success) {
        Message message = DashScopeUtils.createMessage(Role.USER, content);
        call(Collections.singletonList(message), success);
    }

    /**
     * 多轮对话
     *
     * @param messages 对话列表
     * @return {@link Message }
     */
    public static void call(List<Message> messages, Consumer<Message> success) {
        try {
            GenerationParam param = GenerationParam.builder()
                    .model(DashScopeUtils.getConfig().getModel())
                    .messages(messages)
                    .resultFormat(GenerationParam.ResultFormat.MESSAGE)
                    .topP(0.8)
                    .maxTokens(600)
                    .build();
            DashTask dashTask = new DashTask();
            dashTask.setGenerationParam(param);
            dashTask.setSuccessCallback(dash -> success.accept(dash.getResult()));
            DashScopeUtils.submitTask(dashTask);
        } catch (Exception e) {
            throw new RuntimeException("通义千问失败:" + e.getMessage());
        }
    }
}

5)创建runner

runner主要作用:

(1)检查配置文件是否正确配置;

(2)启动任务监听器

@Slf4j
@Component
public class DashScopeRunner {

    private DashListener dashListener;

    @PostConstruct
    public void run() {
        DashScopeConfig config = DashScopeUtils.getConfig();
        if (config == null || ObjectUtil.isEmpty(config.getApiKey())) {
            throw new RuntimeException("灵积服务(通义千问)启动失败,请检查配置文件");
        } else {
            log.info("灵积服务(通义千问)启动");
        }
        dashListener = new DashListener(config.getTime());
        dashListener.start();
    }

    @PostConstruct
    public void shutdown() {
        if (dashListener != null) {
            dashListener.shutdown();
        }
    }
}

4、测试 

5、踩坑

1)token数量验证:每次开始执行任务池中任务状态检查时,要先检查任务token是否和实际一致,避免实际可用token数不足,导致进入死循环

2)任务池中的数据不能使用缓存(redis)

3)成功和失败回调必须是public

4)使用对象池(GenericObjectPool),借出对象,使用完成后必须归还,否则会出现无法借出的情况

5)config中QPS最好小于15,否则会出现限流情况

;