Bootstrap

百度商业AI技术创新大赛赛道二:广告图片描述生成 chat.py源码解析

from PIL import Image
from io import BytesIO
import base64

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # 允许加载截断的图像

from auto import (
    AutoConfigMIX,
    AutoModelMIX,
    AutoProcessorMIX,
    AutoTokenizerMIX,
)
from utils.log import logger

seed = 24
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
dtype = "bfloat16"
#dtype = "float16"
if not paddle.amp.is_bfloat16_supported():
    logger.warning("bfloat16 is not supported on your device,change to float16")
    dtype = "float16"

model_name_or_path = sys.argv[1]
tokenizer = AutoTokenizerMIX.from_pretrained(model_name_or_path)
processor, _ = AutoProcessorMIX.from_pretrained(model_name_or_path)
model_config = AutoConfigMIX.from_pretrained(model_name_or_path, dtype=dtype)
model = AutoModelMIX.from_pretrained(model_name_or_path, config=model_config, dtype=dtype)
model.eval()

prompt = "请描述图片内容"
start = time.time()
query1 = [
    {"image": "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"},
    {"text": prompt},
]
input = processor(query=query1, return_tensors="pd")
query1 = tokenizer.from_list_format(query1)
response, history = model.chat(tokenizer, query=query1, history=None, images=input["images"])
response = response.replace("\n", " ").replace("\r", " ")
print("prompt: %s" % prompt)
print("response: %s" % response)
print("------------------")
    
end = time.time()
length = end - start
print("It took", length, "seconds!")

我无法查看或分析图像。不过,我可以逐行解释你提供的代码。以下是代码的逐行解释:

python

复制

from PIL import Image

导入PIL库中的Image模块,用于处理图像。

python

复制

from io import BytesIO

导入BytesIO模块,以便在内存中处理字节数据

python

复制

import base64

导入base64模块,用于处理Base64编码和解码。

python

复制

from PIL import Image, ImageFile

再次导入PIL库中的Image和ImageFile模块,可能是为了使用ImageFile的特定功能

python

复制

ImageFile.LOAD_TRUNCATED_IMAGES = True  # 允许加载截断的图像

设置ImageFile的属性,允许加载不完整的(截断的)图像。

python

复制

from auto import (
    AutoConfigMIX,
    AutoModelMIX,
    AutoProcessorMIX,
    AutoTokenizerMIX,
)

auto模块导入几个类,用于模型配置、模型处理、模型分词器的自动加载

python

复制

from utils.log import logger

从utils.log模块导入logger,用于记录日志信息。

python

复制

seed = 24

设置随机数种子为24,以确保结果的可复现性。

python

复制

paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)

为不同的库(Paddle、random和NumPy)设置相同的随机种子。

python

复制

dtype = "bfloat16"

设置数据类型为bfloat16。

python

复制

#dtype = "float16"

注释掉的代码,可能是备用选项。

python

复制

if not paddle.amp.is_bfloat16_supported():

检查当前设备是否支持bfloat16数据类型。

python

复制

    logger.warning("bfloat16 is not supported on your device,change to float16")

如果不支持,记录警告信息。

python

复制

    dtype = "float16"

将数据类型更改为float16。

python

复制

model_name_or_path = sys.argv[1]

命令行参数获取模型名称或路径。

python

复制

tokenizer = AutoTokenizerMIX.from_pretrained(model_name_or_path)

加载预训练的分词器。

python

复制

processor, _ = AutoProcessorMIX.from_pretrained(model_name_or_path)

加载预训练的处理器,忽略返回的第二个值。

python

复制

model_config = AutoConfigMIX.from_pretrained(model_name_or_path, dtype=dtype)

加载模型配置,并指定数据类型。

python

复制

model = AutoModelMIX.from_pretrained(model_name_or_path, config=model_config, dtype=dtype)

加载预训练模型,并使用之前加载的配置和数据类型。

python

复制

model.eval()

将模型设置为评估模式,禁用dropout等训练模式下的特性。

python

复制

prompt = "请描述图片内容"

定义一个提示字符串,用于生成描述。

python

复制

start = time.time()

记录当前时间,开始计时。

python

复制

query1 = [
    {"image": "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"},
    {"text": prompt},
]

构建查询列表,包括图片链接和文本提示。

python

复制

input = processor(query=query1, return_tensors="pd")

使用处理器处理查询,并返回张量格式的数据。

python

复制

query1 = tokenizer.from_list_format(query1)

将查询转换为分词器所需的格式。

python

复制

response, history = model.chat(tokenizer, query=query1, history=None, images=input["images"])

调用模型的聊天接口,获取响应和历史记录。

python

复制

response = response.replace("\n", " ").replace("\r", " ")

清理响应文本,去除换行符。

python

复制

print("prompt: %s" % prompt)

打印提示文本。

python

复制

print("response: %s" % response)

打印模型的响应。

python

复制

print("------------------")

打印分隔线。

python

复制

end = time.time()

记录结束时间。

python

复制

length = end - start

计算处理所用的时间。

python

复制

print("It took", length, "seconds!")

打印处理所用的时间。

;