import os

import numpy as np
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

# Local SigLIP checkpoint directory. Override with the SIGLIP_MODEL_PATH
# environment variable instead of editing the source.
MODEL_PATH = os.environ.get("SIGLIP_MODEL_PATH", "/ssd/xiedong/siglip-so400m-patch14-384")

# Initialise the processor and model once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModel.from_pretrained(MODEL_PATH)

# Move the model to the GPU when one is available, otherwise keep it on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# This script only runs inference; make evaluation mode explicit.
model.eval()
# 提取图片特征的函数
def get_image_features(image: Image.Image):
    """
    Compute an L2-normalised SigLIP embedding for an image.

    :param image: input image (PIL.Image)
    :return: the normalised features as a nested Python list (``tensor.tolist()``
        of the model output); presumably one row per image, i.e. shape
        (1, dim) for a single image -- TODO confirm against the model.
    """
    # Preprocess the image and move the resulting tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)
    # inference_mode: like no_grad, but also skips autograd bookkeeping.
    with torch.inference_mode():
        image_features = model.get_image_features(**inputs)
        # L2-normalise along the last dimension. F.normalize clamps the norm
        # with a small epsilon, so an all-zero vector yields zeros, not NaN.
        image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
    # Move to CPU and convert to plain Python lists.
    return image_features.cpu().tolist()
# 示例用法
if __name__ == "__main__":
    # Example usage: fetch a COCO sample image and print its embedding.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Use a timeout so a stalled server cannot hang the script. Fail on HTTP
    # errors instead of handing an error page to PIL. The context manager
    # releases the streamed connection.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # convert() forces PIL to read the whole stream before it is closed,
        # and matches the RGB conversion done by the HTTP service.
        image = Image.open(response.raw).convert("RGB")
    features = get_image_features(image)
    print(np.asarray(features).shape)
    print("Image Features:", features)
import json
import os
import time
import traceback
from io import BytesIO
from typing import Dict, Optional

import requests
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoProcessor

# Local SigLIP checkpoint directory. Override with the SIGLIP_MODEL_PATH
# environment variable instead of editing the source.
MODEL_PATH = os.environ.get("SIGLIP_MODEL_PATH", "/ssd/xiedong/siglip-so400m-patch14-384")

# Initialise the processor and model once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModel.from_pretrained(MODEL_PATH)

# Move the model to the GPU when one is available, otherwise keep it on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# The service only runs inference; make evaluation mode explicit.
model.eval()
# 提取图片特征的函数
def get_image_features(image: Image.Image):
    """
    Compute an L2-normalised SigLIP embedding for an image.

    :param image: input image (PIL.Image)
    :return: the normalised features as a nested Python list (``tensor.tolist()``
        of the model output); presumably one row per image, i.e. shape
        (1, dim) for a single image -- TODO confirm against the model.
    """
    # Preprocess the image and move the resulting tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)
    # inference_mode: like no_grad, but also skips autograd bookkeeping.
    with torch.inference_mode():
        image_features = model.get_image_features(**inputs)
        # L2-normalise along the last dimension. F.normalize clamps the norm
        # with a small epsilon, so an all-zero vector yields zeros instead of
        # NaN (NaN cannot be serialised into the strict-JSON API response).
        image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
    # Move to CPU and convert to plain Python lists for JSON serialisation.
    return image_features.cpu().tolist()
# FastAPI application; the exception handler and the /extract_features route
# below are registered on it.
app = FastAPI()
class CustomHTTPException:
    """Factory for a uniform JSON error response.

    Despite the name, this is not a Python exception. Calling an instance
    returns a ``JSONResponse``. ``to_response_dict`` describes the error for
    the OpenAPI ``responses`` mapping of a route.
    """

    def __init__(self, description: str, error_code: int, detail: str):
        # description: summary shown in the OpenAPI documentation
        # error_code: HTTP status of the response, and the default body "code"
        # detail: base error message placed in the body "message"
        self.description = description
        self.error_code = error_code
        self.detail = detail

    def to_response_dict(self):
        """Return the OpenAPI ``responses`` entry, with an example body."""
        return {
            "description": self.description,
            "content": {
                "application/json": {
                    "example": {"code": self.error_code,
                                "data": {},
                                "message": self.detail,
                                }
                }
            },
        }

    def _build_content(self, extra_detail=None, code=None):
        """Build the JSON body ``{"code", "data", "message"}``.

        :param extra_detail: optional extra information. When given, the message
            becomes a JSON string holding both ``detail`` and ``extra_detail``.
        :param code: value for the body ``code`` field. Defaults to
            ``self.error_code``, so the body agrees with the HTTP status and
            the example advertised by ``to_response_dict``. (Previously the
            default was a hard-coded 500.)
        """
        if code is None:
            code = self.error_code
        if extra_detail is not None:
            message = json.dumps({"detail": self.detail, "extra_detail": extra_detail},
                                 indent=4, ensure_ascii=False)
        else:
            message = self.detail
        return {"code": code,
                "data": {},
                "message": message,
                }

    def __call__(self, extra_detail=None, code=None):
        """Return a ``JSONResponse`` with status ``self.error_code``.

        See ``_build_content`` for the body layout and parameter meanings.
        """
        return JSONResponse(content=self._build_content(extra_detail, code),
                            status_code=self.error_code)
def _format_validation_errors(errors) -> str:
    """Render validation errors as concatenated ``"loc.path:msg;"`` segments.

    ``loc`` entries can be ints (list indices, JSON decode positions), so every
    part is converted with ``str`` before joining.
    """
    return "".join(
        ".".join(str(part) for part in error.get("loc", ())) + ":" + str(error.get("msg", "")) + ";"
        for error in errors
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Turn request-validation failures into the API's uniform 422 JSON body."""
    message = _format_validation_errors(exc.errors())
    return JSONResponse(content={"code": 422,
                                 "data": {},
                                 "message": "输入参数不符合要求: " + message},
                        status_code=422)
# Shared error-response factory for the feature-extraction endpoint
# (HTTP status 505).
ErrorResponseFeatureExtraction = CustomHTTPException(
    description="Feature Extraction API Error",
    error_code=505,
    detail="# 图片特征提取出错\n",
)
class ImageFeatureResponse(BaseModel):
    """Response model of the image feature extraction endpoint.

    It is used as the route's ``response_model``, so FastAPI filters the
    handler's return value through it. Any field the handler returns must be
    declared here, or it is silently dropped from the response.
    """

    code: int  # status code; the handler returns 200 on success
    message: str  # "Success" on success
    features: Optional[Dict] = Field(None, description="提取的图片特征")
    requestId: Optional[str] = Field(None, description="请求 ID")
    # Extra metadata returned by the handler (e.g. processing_time in seconds).
    data: Optional[Dict] = Field(None, description="附加信息,例如处理耗时 processing_time(秒)")
@app.post("/extract_features", response_model=ImageFeatureResponse,
          responses={505: ErrorResponseFeatureExtraction.to_response_dict()})
async def extract_image_features(
        requestId: Optional[str] = Form(None, description="请求 ID"),
        file: UploadFile = File(..., description="上传的图片文件")
):
    """Extract an L2-normalised feature vector from an uploaded image.

    :param requestId: optional client-supplied request ID, echoed back unchanged.
    :param file: uploaded image (multipart form field ``file``).
    :return: on success, a dict with ``code`` (200), ``message``, ``features``
        (``{"feature": ...}``), ``requestId`` and ``data.processing_time``
        (seconds, rounded to 3 decimals). On failure, the
        ``ErrorResponseFeatureExtraction`` JSON response (HTTP 505).
    """
    start_time = time.perf_counter()
    # Check the declared file type. content_type may be None; treat that as
    # "not an image" instead of letting the `in` test raise TypeError.
    if "image" not in (file.content_type or ""):
        return ErrorResponseFeatureExtraction(
            extra_detail=f"# 文件格式不正确,文件类型应为图片,实际为 {file.content_type}"
        )
    try:
        # Read the upload and decode it; RGB conversion normalises
        # grayscale/RGBA/palette inputs for the processor.
        image_bytes = await file.read()
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        # NOTE(review): inference runs synchronously inside this async handler
        # and blocks the event loop while it runs; consider offloading it
        # (e.g. to a threadpool) if concurrent requests matter.
        features = get_image_features(image)
    except Exception:
        # NOTE(review): the full server traceback is returned to the client;
        # consider logging it server-side and returning a generic message.
        return ErrorResponseFeatureExtraction(
            extra_detail=f"# 错误信息: {str(traceback.format_exc())}"
        )
    elapsed = time.perf_counter() - start_time
    return {
        "code": 200,
        "message": "Success",
        "features": {"feature": features},
        "requestId": requestId,
        "data": {
            "processing_time": round(elapsed, 3)
        }
    }
if __name__ == "__main__":
    # Serve the API on all network interfaces, port 7898.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=7898, host="0.0.0.0")
import os

import requests
# 定义下载图片和请求接口的函数
def download_image(url, save_path, timeout=30):
    """
    Download an image to a local file.

    :param url: URL of the image
    :param save_path: local path the image is written to
    :param timeout: seconds to wait for the server. requests has no default
        timeout, so without it a stalled server would hang the script.
    :raises Exception: if the server does not answer with HTTP 200
    """
    # The context manager releases the streamed connection on every path,
    # including the error branch where the body is never read.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"图片下载失败, 状态码: {response.status_code}")
            raise Exception("图片下载失败")
        with open(save_path, 'wb') as file:
            for chunk in response.iter_content(8192):
                file.write(chunk)
    print(f"图片已保存到: {save_path}")
def send_request(server_url, file_path, request_id=None, timeout=120):
    """
    Send a POST request to the feature-extraction endpoint.

    :param server_url: endpoint URL
    :param file_path: local image path, uploaded as the multipart field ``file``
    :param request_id: optional request ID; sent as the form field ``requestId``
        only when truthy
    :param timeout: seconds to wait for the server (inference can be slow)
    :return: the ``requests.Response`` returned by the server
    """
    import mimetypes

    # Declare the file's real MIME type (e.g. image/png for .png) instead of
    # always claiming image/jpeg; fall back to image/jpeg when it is unknown.
    content_type = mimetypes.guess_type(file_path)[0] or 'image/jpeg'
    data = {'requestId': request_id} if request_id else {}
    with open(file_path, 'rb') as file:
        files = {'file': (os.path.basename(file_path), file, content_type)}
        response = requests.post(server_url, files=files, data=data, timeout=timeout)
    return response
if __name__ == "__main__":
    # Sample image and the local file it is saved to.
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image_path = "./000000039769.jpg"
    # Feature-extraction endpoint; change this to match the actual deployment.
    api_url = "http://127.0.0.1:7898/extract_features"

    try:
        # Download the sample image, then upload it to the service.
        download_image(image_url, image_path)
        response = send_request(api_url, image_path, request_id="example_request_id")
        # Report the outcome.
        if response.status_code != 200:
            print(f"接口请求失败, 状态码: {response.status_code}")
            print("错误信息:", response.text)
        else:
            print("接口返回结果:")
            print(response.json())
    except Exception as err:
        print("发生错误:", str(err))
请求格式为 `multipart/form-data`。`file`:上传的图片文件,支持的图片格式包括 JPEG、PNG 等;`requestId`(可选):请求 ID。示例请求:
curl -X POST "http://<server_ip>:<port>/extract_features" \
-H "Content-Type: multipart/form-data" \
-F "file=@/path/to/example.jpg" \
-F "requestId=feature_extraction_request"
`code`:状态码,`200` 表示成功。`message`:`"Success"`。`requestId`:请求 ID(原样返回请求中传入的值)。示例响应(成功):
{
"code": 200,
"message": "Success",
"features": {
"feature": [
0.245123,
0.003452,
-0.114215,
...,
0.451223
]
},
"requestId": "feature_extraction_request",
"data": {
"processing_time": 0.289
}
}
`code`:`505`,表示内部错误。`data`:`{}`。示例响应(异常):
{
"code": 505,
"data": {},
"message": "# 图片特征提取出错\n# 错误信息: Traceback (most recent call last):\n File \"/path/to/app.py\", line 120, in extract_image_features\n image = Image.open(BytesIO(image_bytes)).convert(\"RGB\")\n ...\nUnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7f8e0b89>"
}
`code`:`422`,表示请求参数校验错误。`data`:`{}`。示例响应(参数错误):
{
"code": 422,
"data": {},
"message": "输入参数不符合要求: file:缺少上传的图片; requestId:值应该是字符串;"
}
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC 许可协议。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可,您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。转载请注明出处!