
A Reflection Workflow for Structured JSON Output with LlamaIndex

This notebook shows how to set up a LlamaIndex workflow that extracts the key information from a user's input into structured objects, producing reliable structured JSON output by retrying and reflecting on errors. This is very useful whenever you need the LLM to return text that follows a standard JSON format of your choosing.

This article uses the DashScope LLM from Alibaba Cloud.

pip install -U llama-index-llms-dashscope
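To quickly confirm the integration works before wiring it into a workflow, you can run a one-off completion. This is just a sanity-check sketch; the QWEN_PLUS model choice and the placeholder API key mirror the complete code further below and should be adjusted to your account.

from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels

# Placeholder model and key -- swap in your own DashScope credentials
llm = DashScope(
    model_name=DashScopeGenerationModels.QWEN_PLUS,
    api_key="sk-your-api-key",
)
print(llm.complete("Reply with the single word OK"))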

Since workflows are async-first, this all runs fine in a notebook. If you are running this in your own code, you will want to use asyncio.run() to start an async event loop if one isn't already running:

async def main():
    <async code>

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

Designing the Workflow

To validate the LLM's structured output, we only need two steps:

  • Generate the structured output
  • Validate that the output is correct JSON

The key point is that if the output is invalid, we loop until it is valid, feeding the error back into the next generation attempt; a conceptual sketch of this loop follows below.
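Conceptually, the retry loop looks like the plain-Python sketch below. Here generate_json and build_reflection_prompt are hypothetical placeholders used only for illustration, and CarCollection is the pydantic model we define later in this article; the real implementation expresses the same loop with workflow steps and events.

from pydantic import ValidationError

prompt = "initial extraction prompt"
for attempt in range(3):                       # bounded number of retries
    raw = generate_json(prompt)                # hypothetical: ask the LLM for JSON
    try:
        CarCollection.model_validate_json(raw)
        break                                  # valid JSON that matches the schema: done
    except ValidationError as err:
        # invalid: fold the error message into the next prompt and try again
        prompt = build_reflection_prompt(raw, err)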

Workflow Events

To handle these steps, we need to define a few events:

  • An event to pass along the generated extraction
  • An event to give feedback when the extraction is invalid

The other steps will use the built-in StartEvent and StopEvent events.

from llama_index.core.workflow import Event


class ExtractionDone(Event):
    output: str
    passage: str


class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str
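Events are lightweight data carriers: you construct them with keyword arguments and read their fields as attributes. A quick illustration, assuming the usual pydantic-style field behaviour of llama_index events; the workflow itself creates these events for us.

# Purely illustrative -- not part of the workflow code
done = ExtractionDone(output='{"cars": []}', passage="I own two cars ...")
print(done.output, done.passage)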

Items to Extract

To prompt our model, let's define a pydantic model of what we want to extract.

from pydantic import BaseModel


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]
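These two models give us both the JSON schema we interpolate into the prompt and the validator we use to check the LLM's answer. A quick look at how they will be used (pydantic v2 API):

# The schema dict that gets interpolated into EXTRACTION_PROMPT below
print(CarCollection.model_json_schema())

# Valid JSON parses cleanly into the model ...
ok = '{"cars": [{"brand": "Fiat", "model": "Panda", "power": 45}]}'
print(CarCollection.model_validate_json(ok))

# ... while invalid JSON raises a ValidationError whose message is exactly
# what we feed back to the LLM in the reflection prompt:
# CarCollection.model_validate_json('{"cars": [{"brand": "Fiat"}]}')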

The Workflow Itself

With our events defined, we can construct the workflow and its steps.

Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!

import json

from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels


EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------

Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}

"""

REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------

This caused the JSON decode error: {error}

Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        current_retries = await ctx.get("retries", default=0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            await ctx.set("retries", current_retries + 1)

        if isinstance(ev, StartEvent):
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = DashScope(
            model_name=DashScopeGenerationModels.QWEN_PLUS,
            api_key="sk-your-api-key",
            max_tokens=512,
        )
        prompt = EXTRACTION_PROMPT.format(
            passage=passage, schema=CarCollection.model_json_schema()
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        return ExtractionDone(output=str(output), passage=passage)

    @step
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            CarCollection.model_validate_json(ev.output)
        except Exception as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )

        return StopEvent(result=ev.output)

That's it! Let's explore the workflow we've written a little.

  • We have one entry point, extract (accepts a StartEvent)
  • When extract is done, it emits an ExtractionDone event
  • validate runs and confirms the extraction:
    • If it's fine, it emits a StopEvent and halts the workflow
    • If not, it returns a ValidationErrorEvent with information about the error
  • Any ValidationErrorEvent emitted will trigger the loop, and extract runs again!
  • This continues until the structured output is validated (a visualization sketch follows this list)
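If you want to see this branching and looping as a diagram, LlamaIndex provides an optional visualization helper. It lives in the separate llama-index-utils-workflow package, and the exact import path may differ between versions, so treat this as a sketch:

# pip install llama-index-utils-workflow
from llama_index.utils.workflow import draw_all_possible_flows

# Writes an interactive HTML graph of the steps, events, and the
# extract -> validate -> extract retry loop
draw_all_possible_flows(ReflectionWorkflow, filename="reflection_workflow.html")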

Run the Workflow!

Note: with loops, we need to be mindful of runtime. Here, we set a timeout of 120 seconds.

w = ReflectionWorkflow(timeout=120, verbose=True)
# Run the workflow
ret = await w.run(
    passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
)

Running step extract
Step extract produced event ExtractionDone
Running step validate
Validation failed, retrying...
Step validate produced event ValidationErrorEvent
Running step extract
Step extract produced event ExtractionDone
Running step validate
Step validate produced event StopEvent

print(ret)

{ "cars": [ { "brand": "Fiat", "model": "Panda", "power": 45 }, { "brand": "Honda", "model": "Civic", "power": 330 } ] }
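Because the workflow returns the validated JSON string, you can load it straight back into the pydantic model. A small follow-up sketch, assuming the run succeeded rather than hitting the retry limit:

collection = CarCollection.model_validate_json(ret)
for car in collection.cars:
    print(car.brand, car.model, car.power)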

Complete Code

import asyncio

from llama_index.core.workflow import (
    Event,
    Context,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels
from pydantic import BaseModel


class ExtractionDone(Event):
    output: str
    passage: str


class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]


EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------

Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}

"""

REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------

This caused the JSON decode error: {error}

Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step
    async def extract(
            self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        passage = None
        reflection_prompt = None
        current_retries = await ctx.get("retries", default=0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            await ctx.set("retries", current_retries + 1)

        if isinstance(ev, StartEvent):
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = DashScope(
            model_name=DashScopeGenerationModels.QWEN_PLUS,
            api_key="sk-your-api-key",
            max_tokens=512
        )
        prompt = EXTRACTION_PROMPT.format(
            passage=passage, schema=CarCollection.model_json_schema()
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        return ExtractionDone(output=str(output), passage=passage)

    @step
    async def validate(
            self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            CarCollection.model_validate_json(ev.output)
        except Exception as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )

        return StopEvent(result=ev.output)


async def main():
    w = ReflectionWorkflow(timeout=120, verbose=True)

    # Run the workflow
    ret = await w.run(
        passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
    )
    print("-------------------------------------------------------------------------")
    print(ret)


asyncio.run(main())

  • Replace api_key="sk-your-api-key" in the code with your own Alibaba Cloud DashScope API key, or read it from an environment variable as sketched below.
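A small sketch of the environment-variable approach; the variable name DASHSCOPE_API_KEY is just the name chosen here, so set whichever variable you prefer:

import os

llm = DashScope(
    model_name=DashScopeGenerationModels.QWEN_PLUS,
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    max_tokens=512,
)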