AIMO Python SDK 开发者指南
简介
AIMO Python SDK
是阿加犀 AI 模型优化平台 (AI Model Optimizer, 简称 AIMO) 的 Python 客户端工具包。开发者可以通过该工具包高效的访问 AIMO 的服务,以及与其它系统实现自动化集成。此工具包提供了任务创建、参数设置、任务控制、状态查询、获取结果及下载优化后的模型和相关文件等功能。
使用流程
使用 AIMO Python SDK 的基本流程如下图所示:
- 登录 AIMO:开发者需使用 AIMO 上的个人 API Key 登录到 AIMO 平台后,方可访问对应的功能
- 创建任务:创建优化任务,设置输入信息,例如预训练模型框架,文件路径等信息
- 选择部署平台:确定模型需要部署的芯片平台,包括芯片厂商,芯片型号,模型推理框架
- 设置优化参数:设置模型优化的处理参数,例如是否量化,量化数据精度,采用的量化算法等
- 提交任务:提交任务到 AIMO 平台进行自动处理
- 下载模型:优化任务执行成功后,即可下载优化后的模型文件及其它相关文件
准备开发环境
Python 环境
请确认已安装 Python 3.9 或更高版本。
安装 AIMO Python SDK
可通过 pip 安装此 SDK:
pip install aplux_aimo -i https://mirrors.aidlux.com
💡注意
请确保 pip 是最新版本,以避免潜在的安装问题。
获取 AIMO API Key
💡注意
请确认已注册阿加犀开发者帐号,并能成功登录 AIMO
登录 AIMO 后,点击右上角用户图标弹出下拉菜单,然后点击 “用户密钥”。在弹出的对话框中,将显示 API Key 信息。
快速上手
以下是一个基本的使用 AIMO Python SDK 完成模型优化任务的示例:
from aplux_aimo import AimoApi
from aplux_aimo.enums import SourceModelType, TargetDevice, ModelRuntime, TaskStatus, DownloadFileMode
from aplux_aimo.base_data import QuantizeOptions
from aplux_aimo.enums import ModelDataPrecision, QuantizeMode, CalibrationDataMode, CalibrationDatasetType
# 1. 初始化 AimoApi
# 默认指向公有云服务 "https://aimo.aidlux.com/api/"
aimo = AimoApi()
# 2. 通过API Key 登录 (请替换为你的 API Key)
# 为了安全,建议从环境变量或配置文件读取 API Key
api_key = "YOUR_API_KEY"
try:
aimo.login(api_key=api_key)
print("登录成功!")
except Exception as e:
print(f"登录失败: {e}")
exit()
# 3. 创建新任务
try:
print("正在创建任务...")
task = aimo.new_task(
# 指定源模型框架及模型文件路径
source_model_type=SourceModelType.ONNX,
source_model_file=r"C:\\path\\to\\your\\model.onnx",
# 设置要部署模型的设备及推理框架
target_device = TargetDevice.Qualcomm_QCS6490,
target_runtime = ModelRuntime.QNN_2_28,
description="我的第一个 AIMO 任务"
)
print(f"任务创建成功,任务 ID: {task.task_id}")
# 4. (可选) 配置任务参数,例如量化
# task.config.quantize_options = QuantizeOptions(
# quantize_precision=ModelDataPrecision.INT8,
# quantize_mode=QuantizeMode.Enhanced_CLE,
# enable_per_channel_quantize=True,
# calibration_data_mode=CalibrationDataMode.Image,
# calibration_dataset_type=CalibrationDatasetType.Custom,
# calibration_image_color_mode="rgb",
# calibration_image_mean=[123.675, 116.28, 103.53],
# calibration_image_std=[58.395, 57.12, 57.375]
# )
# print("量化参数已配置。")
# 5. (可选) 如果配置了量化且使用自定义校准集,上传校准数据
# if task.config.quantize_options and \
# task.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom and \
# task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Image:
# calibration_files = [
# r"C:\\path\\to\\your\\image1.png", # 请替换为您的校准图片路径
# r"C:\\path\\to\\your\\image2.jpg" # 请替换为您的校准图片路径
# ]
# task.upload_calibration_data_files(files=calibration_files)
# print("校准数据上传成功。")
# 6. 提交任务
print("正在提交任务...")
task.submit()
print("任务提交成功!")
# 7. 轮询任务状态
print("正在轮询任务状态...")
final_status = task.poll_status(interval=10, timeout=3600) # 每10秒查询一次,最长等待1小时
print(f"任务最终状态: {final_status.value}")
# 8. 获取任务结果
if final_status == TaskStatus.SUCCESS:
print("任务成功,正在获取结果...")
result_info = task.get_result()
print(f"任务结果消息: {result_info.message}")
# print(f"任务配置详情: {result_info.task_config.to_custom_dict()}") # 查看详细配置
# 9. 下载结果文件 (例如,下载优化后的模型)
print("正在下载输出模型...")
download_path = r"C:\\path\\to\\save\\downloads" # 请替换为您的下载保存路径
import os
if not os.path.exists(download_path):
os.makedirs(download_path)
saved_file = task.download(
file_mode=DownloadFileMode.OutputModel,
output_file_path=download_path
)
print(f"输出模型已下载到: {saved_file}")
else:
print(f"任务失败或未成功完成,状态: {final_status.value}")
# 可以获取日志信息排查问题
# error_log = task.get_info_log()
# print(f"任务日志: {error_log}")
except Exception as e:
print(f"操作过程中发生错误: {e}")
# 如果 task 对象已创建,可以尝试获取日志
if 'task' in locals() and task:
try:
print(f"尝试获取任务日志: {task.get_info_log()}")
except Exception as log_e:
print(f"获取日志失败: {log_e}")
SDK API 参考
AimoApi
类 (aplux_aimo.client.AimoApi
)
AimoApi
是与 AIMO 服务进行交互的入口点,负责用户认证和任务的创建与获取。
初始化 AimoApi
客户端
- 参数:
base_url
(str, 可选): AIMO 服务的 API 地址。默认为"https://aimo.aidlux.com/api/"
login(api_key: str)
使用 API Key 登录 AIMO 服务。
- 参数:
api_key
(str): 您的 AIMO API Key。
- 异常:
NotApiKeyError
: 如果api_key
为空。NetworkError
: 如果网络连接失败。APIRequestError
: 如果 API 请求返回错误状态码。
new_task(source_model_type: SourceModelType | str, target_device: TargetDevice | str, target_runtime: ModelRuntime | str, source_model_file: Union[List[str], str] = None, description: str = "", **kwargs) -> AimoTask
创建一个新的优化任务。需要参考 SourceModelType
支持详情 和 TargetDevice
支持详情。
- 参数:
source_model_type
(SourceModelType
| str): 源模型类型,例如SourceModelType.ONNX
或字符串"onnx"
。target_device
(TargetDevice
| str): 目标硬件设备,例如TargetDevice.Qualcomm_QCS6490
或字符串"sm7325"
。target_runtime
(ModelRuntime
| str): 目标模型运行时/格式,例如ModelRuntime.QNN_2_28
或字符串"qnn_2.28"
。source_model_file
(Union[List[str], str], 可选): 源模型文件的本地路径或路径列表。如果提供了此参数,SDK 会自动上传模型文件,如果不提供此参数则提供source_model_urls
参数。description
(str, 可选): 任务的描述信息。**kwargs
: 其他可选的任务配置参数,会透传给TaskConfig
。例如quantize_options=QuantizeOptions(...)
。
- 返回:
AimoTask
实例。 - 异常:
ParameterError
: 如果必要参数缺失或无效。FileNotExistError
: 如果source_model_file
路径无效。APIRequestError
: 如果 API 请求失败。
task(task_id: str) -> AimoTask
获取一个已存在的任务实例。
- 参数:
task_id
(str): 已存在的任务 ID。
- 返回:
AimoTask
实例。 - 异常:
TaskNotExistError
: 如果任务 ID 不存在 (在后续调用任务方法时可能抛出)。
AimoTask
类 (aplux_aimo.client.AimoTask
)
AimoTask
类用于管理单个优化任务,包括配置、上传文件、提交、监控和获取结果等操作。
- 参数:
api_key
(str): AIMO API Key。task_id
(str): 任务的唯一标识符。base_url
(str, 可选): AIMO 服务的 API 地址。
属性
task_id: str
: 任务的唯一 ID。api_key: str
: 用于认证的 API Key。base_url: str
: AIMO 服务的基础 URL。task_config: Optional[TaskConfig]
: 任务的配置信息。可以通过task.config.attribute_name = value
的方式直接修改,或通过set_config()
方法设置。result_info: Optional[ResultTaskInfo]
: 任务完成后的结果信息。在调用get_result()
后填充。task_status: TaskStatus
: 当前任务的执行状态。accuracy_eval_task_status: AccuracyEvalTaskStatus
: 当前精度分析任务的状态。
upload_model(source_model_type: SourceModelType | str, source_model_file: Union[List[str], str], description: str = "", timeout=3600)
上传模型文件到 AIMO。通常在 AimoApi.new_task()
中如果提供了 source_model_file
会自动调用。如果创建任务时未上传模型,可以手动调用此方法。
- 参数:
source_model_type
(SourceModelType
| str): 源模型类型。source_model_file
(Union[List[str], str]): 源模型文件的本地路径或路径列表。description
(str, 可选): 任务描述。timeout
(int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
- 返回: API 响应数据 (dict)。
- 异常:
FileNotFoundError
: 如果模型文件未找到。APIRequestError
: 如果 API 请求失败。- 此方法会更新
task.config.source_model_urls
和task.config.input_output_params
。
upload_calibration_data_files(files: list[str] | str, timeout=3600)
为量化任务上传校准数据集文件(通常是图片)。 仅当 task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Image
且 task.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom
时有效。
- 参数:
files
(list[str] | str): 校准数据集的文件路径列表或单个文件路径/文件夹路径。timeout
(int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
- 返回: API 响应数据 (dict)。
- 异常:
FileNotFoundError
: 如果校准文件未找到。APIRequestError
: 如果 API 请求失败。- 此方法会更新
task.config.quantize_options.calibration_data_files
。
upload_raw_calibration_data_files(raw_calibration_data: list[dict], timeout=3600)
为量化任务上传原始(raw)格式的校准数据。 仅当 task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Raw
时有效。
- 参数:
raw_calibration_data
(list[dict]): 一个列表,每个元素是一个字典,描述一个输入节点的校准数据。 示例:python[ { "input_name": "input.1", # 对应模型的一个输入节点名 "files": ["/path/to/input1_data1.raw", "/path/to/input1_data2.raw"] # 该输入节点的校准文件列表 }, { "input_name": "input.2", "files": ["/path/to/input2_data1.raw"] } ]
timeout
(int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
- 异常:
ParameterError
: 如果input_node
数据未找到或格式错误。FileNotFoundError
: 如果校准文件未找到。APIRequestError
: 如果 API 请求失败。- 此方法会更新
task.config.input_output_params.input_nodes
中对应节点的files
属性。
is_exist() -> bool
检查任务在服务端是否真实存在。
- 返回:
bool
,如果任务存在则为True
,否则为False
。 - 异常:
APIRequestError
: 如果 API 请求失败。
submit() -> bool
提交任务到 AIMO 服务执行优化/转换。 在提交前,请确保所有必要的配置(如 target_device
, target_runtime
, source_model_urls
or source_model_file
)已设置,并且如果需要量化,相关的 quantize_options
和校准数据已准备好。
- 返回:
bool
,如果提交请求成功则为True
。 - 异常:
ParameterError
: 如果任务配置不完整(例如 TensorFlow PB 模型缺少输出节点名)。AimoTaskError
: 如果任务提交失败(例如任务状态不允许提交)。APIRequestError
: 如果 API 请求失败。
poll_status(interval: int = 5, timeout: int = 3600) -> TaskStatus
轮询任务的执行状态,直到任务完成(成功或失败)或超时。
- 参数:
interval
(int, 可选): 轮询间隔时间(秒),默认为 5 秒。timeout
(int, 可选): 最大等待时间(秒),默认为 3600 秒 (1小时)。如果超时,将返回当前最后一次获取到的状态。
- 返回:
TaskStatus
枚举成员,表示任务的最终状态。 - 异常:
APIRequestError
: 如果 API 请求在轮询过程中失败。
get_result() -> ResultTaskInfo
获取已完成任务的详细结果信息。
- 返回:
ResultTaskInfo
实例,包含了任务配置、状态、消息、耗时等。 - 异常:
AimoTaskError
: 如果任务尚未完成或获取结果失败。APIRequestError
: 如果 API 请求失败。- 此方法会填充
task.result_info
和task.task_status
属性。
download(file_mode: DownloadFileMode = DownloadFileMode.OutputModel, output_file_path: str = os.getcwd()) -> str
下载与任务相关的文件。
- 参数:
file_mode
(DownloadFileMode
, 可选): 要下载的文件类型,默认为DownloadFileMode.OutputModel
(优化后的模型)。也可以是DownloadFileMode.InputModel
(原始输入模型)。output_file_path
(str, 可选): 下载文件保存的本地目录路径。默认为当前工作目录。
- 返回:
str
,下载文件的完整本地路径。 - 异常:
AimoTaskError
: 如果无法获取下载链接或下载失败。APIRequestError
: 如果 API 请求失败。IOError
: 如果文件写入失败。
terminate() -> Dict[str, Any]
请求终止正在执行的任务。
- 返回: API 响应数据 (dict)。
- 异常:
AimoTaskError
: 如果任务终止失败。APIRequestError
: 如果 API 请求失败。
delete() -> Dict[str, Any]
请求删除服务端上的任务。
- 返回: API 响应数据 (dict)。
- 异常:
AimoTaskError
: 如果任务删除失败。APIRequestError
: 如果 API 请求失败。
submit_accuracy_eval() -> bool
提交精度分析任务。在优化任务成功完成后调用。
- 返回:
bool
,如果提交请求成功则为True
。 - 异常:
AimoTaskError
: 如果提交精度分析失败。APIRequestError
: 如果 API 请求失败。
poll_accuracy_eval_status(interval: int = 5, timeout: int = 3600) -> AccuracyEvalTaskStatus
轮询精度分析任务的状态。
- 参数:
interval
(int, 可选): 轮询间隔时间(秒),默认为 5 秒。timeout
(int, 可选): 最大等待时间(秒),默认为 3600 秒 (1小时)。
- 返回:
AccuracyEvalTaskStatus
枚举成员。 - 异常:
APIRequestError
: 如果 API 请求在轮询过程中失败。
get_accuracy_eval_result() -> AccuracyEvalResult
获取精度分析的结果。
- 返回:
AccuracyEvalResult
实例。 - 异常:
AimoTaskError
: 如果获取精度分析结果失败。APIRequestError
: 如果 API 请求失败。- 此方法会更新
task.accuracy_eval_task_status
。
get_info_log() -> str
获取优化任务的日志信息。
- 返回:
str
,包含任务日志的字符串。 - 异常:
AimoTaskError
: 如果获取日志失败。APIRequestError
: 如果 API 请求失败。
数据类 (aplux_aimo.base_data
)
SDK 使用 dataclass 来组织和传递数据。
QuantizeOptions
量化相关选项:
quantize_precision: ModelDataPrecision | str
: 量化精度,例如ModelDataPrecision.INT8
("A8_W8"),ModelDataPrecision.INT16
("A16_W8")。quantize_mode: QuantizeMode | str
: 量化模式,例如QuantizeMode.Enhanced_CLE
("cle"),QuantizeMode.Enhanced_ADA
("ada")。enable_per_channel_quantize: bool
: 是否启用 Per-Channel 量化,默认为False
。calibration_data_mode: CalibrationDataMode | str
: 校准数据模式,例如CalibrationDataMode.Image
("cv"),CalibrationDataMode.Raw
("nlp"),CalibrationDataMode.Random
("random")。calibration_dataset_type: CalibrationDatasetType | str
: 校准数据集类型,例如CalibrationDatasetType.ImageNet
("imagenet"),CalibrationDatasetType.COCO
("coco"),CalibrationDatasetType.Face
("face"),CalibrationDatasetType.Custom
("")。calibration_data_files: list[str]
: 自定义校准数据集的文件 URL 列表 (由 SDK 内部上传后填充)。calibration_image_color_mode: str
: 校准图像颜色模式 (当calibration_data_mode
为 Image 时),例如 "rgb", "bgr"。默认为 "rgb"。calibration_image_mean: list[float]
: 校准图像均值列表,例如[123.675, 116.28, 103.53]
。calibration_image_std: list[float]
: 校准图像标准差列表,例如[58.395, 57.12, 57.375]
。
InputNode
模型输入节点信息。
node_name: str
: 输入节点名称。dimension: list[int]
: 输入节点形状/维度,例如[1, 3, 224, 224]
。files: list[str]
: 当CalibrationDataMode
为Raw
时,该输入节点对应的校准数据文件 URL 列表,调用upload_raw_calibration_data_files
方法后内部会自动上传校准数据文件到服务端,不需要用户手动上传。
InputOutputParams
模型的输入输出参数。
input_nodes: list[InputNode] | list[dict]
: 输入节点列表,每个元素是InputNode
实例或其字典表示。output_node_name: list[str]
: 输出节点名称列表。
TaskConfig
任务的完整配置。
task_id: str
: 任务 ID。description: str
: 任务描述。target_device: TargetDevice | str
: 目标硬件设备。target_device_name: str
: 目标硬件设备名称 (通常由 SDK 根据target_device
填充)。target_runtime: ModelRuntime | str
: 目标模型运行时/格式。input_output_params: Optional[InputOutputParams]
: 模型的输入输出参数。quantize_options: Optional[QuantizeOptions]
: 量化配置选项。source_model_type: SourceModelType | str
: 源模型类型。source_model_file: Optional[list[str]]
: (本地) 源模型文件路径列表 (主要用于创建任务时 SDK 内部记录,实际提交给服务端的是source_model_urls
)。source_model_urls: list[str]
: 源模型文件在服务端的 URL 列表 (由模型上传后填充)。
AccuracyEvalResult
精度分析结果:
id: str
: 精度分析任务的 ID。status: AccuracyEvalTaskStatus
: 精度分析任务的状态。result_json: list
: 精度分析结果的 JSON 数据。message: str
: 精度分析相关消息。
ResultTaskInfo
任务执行完成后的详细结果信息:
task_id: str
: 任务 ID。task_config: Optional[TaskConfig]
: 完成时的任务配置快照。message: str
: 任务执行结果的消息。create_time: str
: 任务创建时间戳。usage_time: str
: 任务执行耗时。accuracy_eval_result: Optional[AccuracyEvalResult]
: 精度分析结果 (如果执行了)。task_status: TaskStatus | str
: 任务的最终状态。input_model: str
: 输入模型的 URL (通常是 zip 格式)。
UpdateFileInfo
文件上传操作后的返回信息:
file_datas: list[dict]
: 上传成功的每个文件的信息列表,包含url
等。model_info: dict
: 如果上传的是模型文件,这里会包含从模型解析出的信息 (例如输入输出节点)。
枚举类型 (aplux_aimo.enums
)
SDK 定义了多个枚举类型以规范参数。
SourceModelType
源模型类型:
ONNX = "onnx"
PyTorch = "pt"
TensorFlow_PB = "pb"
TensorFlow_Lite = "tflite"
TensorFlow_Save_Model = "pbsm"
Caffe = "caffe"
PaddlePaddle = "pd"
ModelRuntime
目标模型运行时/格式。
ONNX = "onnx"
TFLite = "tflite"
PaddleLite = "nb"
TNN = "tnn"
MNN = "mnn"
NCNN = "ncnn"
MindSpore = "ms"
RKNN_2 = "rknn"
SNPE_1 = "dlc_1.x"
SNPE_2_16 = "dlc_2.16"
SNPE_2_23 = "dlc_2.23"
SNPE_2_29 = "dlc_2.29"
SNPE_2_31 = "dlc_2.31"
QNN_2_16 = "qnn_2.16"
QNN_2_23 = "qnn_2.23"
QNN_2_29 = "qnn_2.29"
QNN_2_31 = "qnn_2.31"
TargetDevice
目标硬件设备。每个枚举成员包含一个 value
(用于 API) 和一个 label
(描述性名称)。
Qualcomm_QCS6490 = ("sm7325", "QCS6490")
Qualcomm_QCS8250 = ("8250", "QCS8250")
Qualcomm_QCS8550 = ("qcs8550", "QCS8550")
Qualcomm_Snapdragon_8_Gen1 = ("sm8450", "Snapdragon 8 Gen1")
Qualcomm_Snapdragon_8_Gen2 = ("sm8550", "Snapdragon 8 Gen2")
Qualcomm_Snapdragon_8_Gen3 = ("sm8650", "Snapdragon 8 Gen3")
Qualcomm_Snapdragon_8_Elite_Gen1 = ("sm8750", "Snapdragon 8 Elite")
Rockchip_RK3588 = ("3588", "RK3588")
使用时,可以直接用枚举成员,SDK 会自动取其 value
,例如 TargetDevice.Qualcomm_QCS6490
。
QuantizeMode
量化模式:
Enhanced_ADA = "ada"
Enhanced_CLE = "cle"
ModelDataPrecision
模型量化数据精度:
INT8 = "A8_W8"
INT16 = "A16_W8"
CalibrationDataMode
校准数据模式 (对应 API 中的 task_type
):
Image = "cv"
(图像数据)Raw = "nlp"
(原始二进制数据,常用于 NLP 模型)Random = "random"
(使用随机生成的数据)
CalibrationDatasetType
预设的校准数据集类型或自定义。
ImageNet = "imagenet"
COCO = "coco"
Face = "face"
Custom = ""
(表示使用用户自定义上传的数据集)
DownloadFileMode
下载文件类型。
InputModel = "input_model"
(下载原始输入模型)OutputModel = "output_model"
(下载优化/转换后的输出模型)
TaskStatus
任务执行状态。
PENDING = "PENDING"
(排队等待)STARTED = "START"
(已开始)SUCCESS = "SUCCESS"
(成功)FAILURE = "FAILURE"
(失败)NOTSUBMIT = "NOTSUBMIT"
(未提交)
AccuracyEvalTaskStatus
精度分析任务状态。
PENDING = "Waiting"
(等待中)STARTED = "Running"
(运行中)SUCCESS = "Success"
(成功)FAILURE = "Failure"
(失败)NOTSUBMIT = "Notsubmit"
(未提交)
高级功能与用例
操作已有任务
如果您知道一个任务的 ID,可以使用 api.task(task_id)
来获取该任务的 AimoTask
实例,然后对其进行操作,例如查询状态、获取结果或下载文件。
# ... (api 初始化和登录同上) ...
known_task_id = "your_existing_task_id" # 替换为已存在的任务ID
try:
existing_task = api.task(task_id=known_task_id)
print(f"获取任务 {existing_task.task_id} 成功。")
# 检查任务是否存在 (可选,因为 task() 本身不直接校验)
if not existing_task.is_exist():
print(f"任务 {existing_task.task_id} 在服务端不存在。")
else:
print(f"当前状态: {existing_task.poll_status(timeout=60).value}") # 短暂轮询获取最新状态
if existing_task.task_status == TaskStatus.SUCCESS:
result = existing_task.get_result()
print(f"任务结果: {result.message}")
# existing_task.download(...)
elif existing_task.task_status == TaskStatus.FAILURE:
print(f"任务失败,日志: {existing_task.get_info_log()}")
except Exception as e:
print(f"操作已有任务时出错: {e}")
模型量化
模型量化通过 TaskConfig
中的 quantize_options
属性进行配置。
详细配置 QuantizeOptions
from aplux_aimo.base_data import QuantizeOptions
from aplux_aimo.enums import ModelDataPrecision, QuantizeMode, CalibrationDataMode, CalibrationDatasetType
# ... (task 创建同上) ...
task.config.quantize_options = QuantizeOptions(
quantize_precision=ModelDataPrecision.INT8, # 或 ModelDataPrecision.INT16
quantize_mode=QuantizeMode.Enhanced_CLE, # 或 QuantizeMode.Enhanced_ADA
enable_per_channel_quantize=True, # 是否启用 per-channel 量化
calibration_data_mode=CalibrationDataMode.Image, # 校准数据类型:Image, Raw, Random
calibration_dataset_type=CalibrationDatasetType.Custom, # 校准数据集:ImageNet, COCO, Face, Custom
calibration_image_color_mode="rgb", # 如果是 Image, 指定颜色模式
calibration_image_mean=[123.675, 116.28, 103.53], # 图像均值
calibration_image_std=[58.395, 57.12, 57.375] # 图像标准差
)
print("量化参数已配置。")
上传图像校准数据集
如果 calibration_dataset_type
设置为 CalibrationDatasetType.Custom
且 calibration_data_mode
为 CalibrationDataMode.Image
,则需要上传校准图片。
# ... (假设 task 和 quantize_options 已配置) ...
if task.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom and \
task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Image:
calibration_image_paths = [
r"/path/to/your/calibration_image1.jpg",
r"/path/to/your/calibration_image2.png",
# ... 更多图片路径
]
# 也可以是一个包含图片的文件夹路径
# calibration_image_paths = r"/path/to/your/calibration_images_folder/"
try:
task.upload_calibration_data_files(files=calibration_image_paths)
print("图像校准数据集上传成功。")
except Exception as e:
print(f"上传图像校准数据失败: {e}")
# 之后再 task.submit()
上传原始校准数据集
如果 calibration_data_mode
为 CalibrationDataMode.Raw
,则需要上传原始格式的校准数据。
# ... (假设 task 和 quantize_options 已配置) ...
if task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Raw:
# 确保 task.config.input_output_params.input_nodes 已经通过模型上传被填充
# 或者在创建任务时通过 **kwargs 传递了 input_output_params
# 示例:假设模型有两个输入节点 "input_node_A" 和 "input_node_B"
# 并且这些节点名存在于 task.config.input_output_params.input_nodes[i].node_name
raw_data_config = [
{
"input_name": "input_node_A", # 替换为实际的输入节点名
"files": [
r"/path/to/input_A_sample1.raw",
r"/path/to/input_A_sample2.raw"
]
},
{
"input_name": "input_node_B", # 替换为实际的输入节点名
"files": [r"/path/to/input_B_sample1.raw"]
}
]
try:
task.upload_raw_calibration_data_files(raw_calibration_data=raw_data_config)
print("原始校准数据集上传成功。")
except Exception as e:
print(f"上传原始校准数据失败: {e}")
# 之后再 task.submit()
模型精准度分析
在优化任务成功完成后,您可以为该任务提交一个模型精准度分析子任务,了解模型每一层的精准度情况。
# ... (假设 task 是一个已成功完成的 AimoTask 实例) ...
if task.task_status == TaskStatus.SUCCESS:
try:
print("正在提交精度分析任务...")
task.submit_accuracy_eval()
print("精度分析任务已提交。")
print("正在轮询精度分析状态...")
accuracy_status = task.poll_accuracy_eval_status(interval=5, timeout=1800) # 等待最多30分钟
print(f"精度分析最终状态: {accuracy_status.value}")
if accuracy_status == AccuracyEvalTaskStatus.SUCCESS:
accuracy_result = task.get_accuracy_eval_result()
print("获取精度分析结果成功:")
print(f" ID: {accuracy_result.id}")
print(f" Message: {accuracy_result.message}")
print(f" Result JSON: {accuracy_result.result_json}") # 详细结果
else:
print(f"精度分析失败或未成功完成。状态: {accuracy_status.value}")
except Exception as e:
print(f"精度分析过程中出错: {e}")
获取任务日志
使用 task.get_info_log()
可以获取任务执行过程中的日志,有助于排查问题。
# ... (假设 task 是一个已提交或已完成的 AimoTask 实例) ...
try:
logs = task.get_info_log()
print("任务日志:")
print(logs)
except Exception as e:
print(f"获取任务日志失败: {e}")
删除任务
可以使用 task.delete()
从服务端删除一个任务。
# ... (假设 task 是一个 AimoTask 实例) ...
try:
response = task.delete()
print(f"任务 {task.task_id} 删除请求已发送。响应: {response}")
except Exception as e:
print(f"删除任务失败: {e}")
错误处理与异常 (aplux_aimo.exceptions
)
SDK 定义了一系列自定义异常来帮助开发者处理特定错误情况。所有自定义异常都继承自 AimoError
。
AimoError
: 所有 AIMO SDK 异常的基类。NotApiKeyError
: API Key 未设置或无效时抛出。message
: "API Key is not set, please call login method first"
TaskNotExistError
: 尝试操作一个在服务端不存在的任务时(或任务ID无效)。message
: "Task does not exist"
ParameterError
: API 调用时参数错误或缺失。message
: "Parameter error" (或更具体的错误信息)
NetworkError
:发生网络连接问题。message
: "Network Error"
APIRequestError
: API 服务端返回错误状态码或非预期响应。message
: "API Request Error" (或更具体的错误信息)
FileNotExistError
: 指定的本地文件路径不存在。message
: "File does not exist"
AimoTaskError
: 通用的任务相关操作错误(例如状态不符、操作失败等)。message
: "Task Error" (或更具体的错误信息)
建议的错误处理示例:
from aplux_aimo import AimoApi
from aplux_aimo.exceptions import NotApiKeyError, ParameterError, APIRequestError, AimoTaskError, FileNotExistError
api = AimoApi()
try:
api.login(api_key="YOUR_API_KEY") # 请替换
task = api.new_task(
source_model_type="onnx",
source_model_file=r"C:\\path\\to\\your\\model.onnx", # 请替换
target_device="sm7325",
target_runtime="qnn_2.28"
)
task.submit()
task.poll_status()
result = task.get_result()
print(f"任务成功: {result.message}")
except NotApiKeyError as e:
print(f"认证失败: {e}")
except FileNotExistError as e:
print(f"文件错误: {e}")
except ParameterError as e:
print(f"参数配置错误: {e}")
except APIRequestError as e:
print(f"API 请求失败: {e}")
# 可以检查 e.args 或 e.message 获取更多服务端返回信息
except AimoTaskError as e:
print(f"任务处理失败: {e}")
if 'task' in locals() and task: # 如果 task 对象存在
try:
print(f"尝试获取任务日志: {task.get_info_log()}")
except Exception as log_e:
print(f"获取日志也失败了: {log_e}")
except Exception as e: # 其他未知错误
print(f"发生未知错误: {e}")
常见问题 (FAQ)
Q1: 如何获取 API Key?A1: 您需要在 AIMO 官方平台 (例如:https://aimo.aidlux.com/) 注册账户,并在右上角点击用户头像然后点击用户密钥获取 API Key。
Q2: 支持哪些模型格式和硬件设备?A2: 请参考本文档中
SourceModelType
(源模型类型) 和TargetDevice
(目标硬件设备) 小节的详细列表。Q3: 量化失败如何排查?A3:
- 检查
QuantizeOptions
配置是否正确,特别是校准模式、数据集类型、精度选择。 - 如果使用自定义校准集,确保校准数据质量良好且格式正确,文件已成功上传。
- 查看
task.get_info_log()
获取详细的转换日志,其中可能包含量化过程中的错误信息。 - 确认模型本身是否适合进行所选精度的量化。
- 检查
Q4: 为什么当SourceModelType模型为TensorFlow(saved model)时,aimo页面上传的文件时文件夹,而sdk上传的文件是zip文件?A4: 这是因为TensorFlow(saved model)模型在AIMO平台上的上传方式与SDK上传方式不同,aimo页面上传后同样是处理成zip包。
附录
SourceModelType
支持详情
下表列出了不同的源模型框架类型 (SourceModelType
) ,所需文件数量,文件类型,可转换的目标推理运行时 (ModelRuntime
):
源模型框架类型 (SourceModelType) | 值 (value) | 文件数量 (file_len) | 支持文件类型 (file_type) | 支持的目标运行时 (ModelRuntime) |
---|---|---|---|---|
ONNX | onnx | 1 | ["onnx"] | TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
PyTorch | pt | 1 | ["pt"] 注意: 模型文件需包含完整的模型结构和参数 | ONNX(onnx ), TFLite(tflite ), SNPE(dlc_* ), QNN 2.23(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
TensorFlow(frozen pb) | pb | 1 | ["pb"] | ONNX(onnx ), TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
TensorFlow(saved model) | pbsm | 1 | ["zip"] | ONNX(onnx ), TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
TensorFlow Lite | tflite | 1 | ["tflite"] | TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
Caffe | caffe | 2 | [".prototxt", ".caffemodel"] | ONNX(onnx ), TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
Paddle Paddle | pd | 2 | [".pdmodel", ".pdiparams"] | ONNX(onnx ), TFLite(tflite ), SNPE(dlc_* ), QNN(qnn_* ), RKNN_2(rknn ), Paddle Lite(nb ), TNN(tnn ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ) |
MXNet | MXNet | 2 | [".json", ".params"] | ONNX(onnx ), TFLite(tflite ), Paddle Lite (nb ), TNN (tnn ), MNN (mnn ), NCNN (ncnn ), MindSpore (ms ) |
TargetDevice
支持详情
下表列出了不同的目标硬件设备 (TargetDevice
) 及其支持的目标运行时 (ModelRuntime
) 和支持的量化数据精度选项 (ModelDataPrecision
)
目标设备 (label) | 值 (value) | 支持的目标运行时 (ModelRuntime) | 支持的量化数据精度 (ModelDataPrecision) |
---|---|---|---|
Snapdragon 888 | sm8350 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
Snapdragon 8 Gen1 | sm8450 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
Snapdragon 8 Gen2 | sm8550 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
Snapdragon 8 Gen3 | sm8650 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
Snapdragon 8 Elite | sm8750 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
QCS6490 | sm7325 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
QCS8250 | sm8250 | SNPE_1_x(dlc_1.x ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8 (W8A8 ) |
QCS8550 | qcs8550 | SNPE_2_x(dlc_2.* ), QNN_2_x(qnn_2.* ), TFLite(tflite ), MNN(mnn ), NCNN(ncnn ), MindSpore(ms ), PaddleLite(nb ), TNN(tnn ), ONNX(onnx ) | INT8(W8A8 ), INT16(W8A16 ) |
RK3588 | rk3588 | RKNN_2(rknn ), TFLite(tflite ), ONNX(onnx ) | INT8 (W8A8 ) |
通用 Arm CPU & GPU | "" | TFLite(tflite ), MNN(mnn ),NCNN (ncnn ), MindSpore(ms ), Paddle Lite(nb ), TNN(tnn ), ONNX(onnx ) |