2024-12-10
深度学习
00

目录

推理图片:
fastapi:
客户端请求
接口文档
提取图像特征接口 POST /extract_features
请求体
响应
HTTP 状态码==200
HTTP 状态码==505
HTTP 状态码==422

推理图片:

python
import numpy as np from PIL import Image import requests from transformers import AutoProcessor, AutoModel import torch # 初始化模型和处理器 processor = AutoProcessor.from_pretrained("/ssd/xiedong/siglip-so400m-patch14-384") model = AutoModel.from_pretrained("/ssd/xiedong/siglip-so400m-patch14-384") # 将模型移动到 GPU 上 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 提取图片特征的函数 def get_image_features(image: Image.Image): """ 输入一张图片,输出其特征向量。 :param image: 输入图片 (PIL.Image) :return: 图片特征向量 (list) """ # 处理图像输入,使其可以输入模型 inputs = processor(images=image, return_tensors="pt").to(device) # 将输入张量移动到设备上 # 不计算梯度 (inference 模式) with torch.no_grad(): # 提取图片特征 image_features = model.get_image_features(**inputs) # 对特征进行 L2 归一化 image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True) # 转换为 CPU 上的 list return image_features.cpu().tolist() # 示例用法 if __name__ == "__main__": url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) features = get_image_features(image) print(np.asarray(features).shape) print("Image Features:", features)

fastapi:

python
import json from PIL import Image import requests from transformers import AutoProcessor, AutoModel import torch import traceback import time from typing import Dict, Optional from fastapi import FastAPI, File, UploadFile, Form from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, Field from PIL import Image from io import BytesIO # 初始化模型和处理器 processor = AutoProcessor.from_pretrained("/ssd/xiedong/siglip-so400m-patch14-384") model = AutoModel.from_pretrained("/ssd/xiedong/siglip-so400m-patch14-384") # 将模型移动到 GPU 上 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 提取图片特征的函数 def get_image_features(image: Image.Image): """ 输入一张图片,输出其特征向量。 :param image: 输入图片 (PIL.Image) :return: 图片特征向量 (list) """ # 处理图像输入,使其可以输入模型 inputs = processor(images=image, return_tensors="pt").to(device) # 将输入张量移动到设备上 # 不计算梯度 (inference 模式) with torch.no_grad(): # 提取图片特征 image_features = model.get_image_features(**inputs) # 对特征进行 L2 归一化 image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True) # 转换为 CPU 上的 list return image_features.cpu().tolist() app = FastAPI() class CustomHTTPException(): """自定义异常""" def __init__(self, description: str, error_code: int, detail: str): self.description = description self.error_code = error_code self.detail = detail def to_response_dict(self): """格式化响应数据""" return { "description": self.description, "content": { "application/json": { "example": {"code": self.error_code, "data": {}, "message": self.detail, } } }, } def __call__(self, extra_detail=None, code=500): """自定义异常响应""" if extra_detail is not None: detailx = json.dumps({"detail": self.detail, "extra_detail": extra_detail}, indent=4, ensure_ascii=False) else: detailx = self.detail return JSONResponse(content={"code": code, "data": {}, "message": detailx, }, status_code=self.error_code) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): """处理请求校验异常""" message = "" for error in exc.errors(): message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";" return JSONResponse(content={"code": 422, "data": {}, "message": "输入参数不符合要求: " + str(message)}, status_code=422) ErrorResponseFeatureExtraction = CustomHTTPException("Feature Extraction API Error", 505, "# 图片特征提取出错\n") class ImageFeatureResponse(BaseModel): """图像特征接口返回模型""" code: int message: str features: Optional[Dict] = Field(None, description="提取的图片特征") requestId: Optional[str] = Field(None, description="请求 ID") @app.post("/extract_features", response_model=ImageFeatureResponse, responses={505: ErrorResponseFeatureExtraction.to_response_dict()}) async def extract_image_features( requestId: Optional[str] = Form(None, description="请求 ID"), file: UploadFile = File(..., description="上传的图片文件") ): try: start_time = time.time() # 检查文件格式 if "image" not in file.content_type: return ErrorResponseFeatureExtraction( extra_detail=f"# 文件格式不正确,文件类型应为图片,实际为 {file.content_type}" ) # 读取图像数据 image_bytes = await file.read() image = Image.open(BytesIO(image_bytes)).convert("RGB") # 提取图像特征 features = get_image_features(image) end_time = time.time() return { "code": 200, "message": "Success", "features": {"feature": features}, "requestId": requestId, "data": { "processing_time": round(end_time - start_time, 3) } } except Exception as e: return ErrorResponseFeatureExtraction( extra_detail=f"# 错误信息: {str(traceback.format_exc())}" ) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7898)

客户端请求

python
import requests import os # 定义下载图片和请求接口的函数 def download_image(url, save_path): """ 下载图片到本地 :param url: 图片的URL :param save_path: 保存图片的路径 """ response = requests.get(url, stream=True) if response.status_code == 200: with open(save_path, 'wb') as file: for chunk in response.iter_content(1024): file.write(chunk) print(f"图片已保存到: {save_path}") else: print(f"图片下载失败, 状态码: {response.status_code}") raise Exception("图片下载失败") def send_request(server_url, file_path, request_id=None): """ 向特征提取接口发送 POST 请求 :param server_url: 接口地址 :param file_path: 本地图片路径 :param request_id: 请求 ID(可选) :return: 服务器返回的响应 """ with open(file_path, 'rb') as file: files = {'file': (os.path.basename(file_path), file, 'image/jpeg')} data = {'requestId': request_id} if request_id else {} response = requests.post(server_url, files=files, data=data) return response if __name__ == "__main__": # 图片下载 URL 和保存路径 image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" image_path = "./000000039769.jpg" # API 接口 URL api_url = "http://127.0.0.1:7898/extract_features" # 根据实际部署地址修改 try: # 1. 下载图片到本地 download_image(image_url, image_path) # 2. 发送请求到接口 request_id = "example_request_id" response = send_request(api_url, image_path, request_id=request_id) # 3. 打印服务器返回结果 if response.status_code == 200: print("接口返回结果:") print(response.json()) else: print(f"接口请求失败, 状态码: {response.status_code}") print("错误信息:", response.text) except Exception as e: print("发生错误:", str(e))

接口文档

提取图像特征接口 POST /extract_features

  • 概述:提取给定图像的特征向量。
  • 描述:此接口接受上传的图片文件,提取其高维特征向量,用于后续的图像检索及分析。

请求体

  • file (文件): 必须,通过 multipart/form-data 上传的图片文件,支持的图片格式包括 JPEG, PNG 等。
  • requestId (字符串, 可选): 每个请求的唯一标识符,接口会原样返回。

示例请求

bash
curl -X POST "http://<server_ip>:<port>/extract_features" \ -H "Content-Type: multipart/form-data" \ -F "file=@/path/to/example.jpg" \ -F "requestId=feature_extraction_request"

响应

HTTP 状态码==200

  • 响应内容
    • code (整数): 固定为 200,表示成功。
    • message (字符串): 成功提示,固定为 "Success"
    • features (字典): 包含提取的图像特征向量。
      • feature (数组): 图像的特征向量,每一项为浮点数(高维特征向量,可能包含上百维)。
    • requestId (字符串, 可选): 请求的唯一标识符(如果输入了 requestId)。
    • data (字典, 可选):
      • processing_time (浮点数): 处理耗时,单位为秒,表示从接收到请求到返回特征向量所消耗的时间。

示例响应(成功)

json
{ "code": 200, "message": "Success", "features": { "feature": [ 0.245123, 0.003452, -0.114215, ..., 0.451223 ] }, "requestId": "feature_extraction_request", "data": { "processing_time": 0.289 } }

HTTP 状态码==505

  • 响应内容
    • code (整数): 固定为 505,表示内部错误。
    • data (字典): 固定为空字典 {}
    • message (字符串): 包含错误信息,包括基础错误提示和详细堆栈信息(如果有)。

示例响应(异常)

json
{ "code": 505, "data": {}, "message": "# 图片特征提取出错\n# 错误信息: Traceback (most recent call last):\n File \"/path/to/app.py\", line 120, in extract_image_features\n image = Image.open(BytesIO(image_bytes)).convert(\"RGB\")\n ...\nUnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7f8e0b89>", }

HTTP 状态码==422

  • 响应内容
    • code (整数): 固定为 422,表示请求参数校验错误。
    • data (字典): 固定为空字典 {}
    • message (字符串): 输入参数验证失败的详细描述。

示例响应(参数错误)

json
{ "code": 422, "data": {}, "message": "输入参数不符合要求: file:缺少上传的图片; requestId:值应该是字符串;" }
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!