Skip to content

AIMO Python SDK 开发者指南

简介

AIMO Python SDK 是阿加犀 AI 模型优化平台 (AI Model Optimizer, 简称 AIMO) 的 Python 客户端工具包。开发者可以通过该工具包高效的访问 AIMO 的服务,以及与其它系统实现自动化集成。此工具包提供了任务创建、参数设置、任务控制、状态查询、获取结果及下载优化后的模型和相关文件等功能。

使用流程

使用 AIMO Python SDK 的基本流程如下图所示:

  1. 登录 AIMO:开发者需使用 AIMO 上的个人 API Key 登录到 AIMO 平台后,方可访问对应的功能
  2. 创建任务:创建优化任务,设置输入信息,例如预训练模型框架,文件路径等信息
  3. 选择部署平台:确定模型需要部署的芯片平台,包括芯片厂商,芯片型号,模型推理框架
  4. 设置优化参数:设置模型优化的处理参数,例如是否量化,量化数据精度,采用的量化算法等
  5. 提交任务:提交任务到 AIMO 平台进行自动处理
  6. 下载模型:优化任务执行成功后,即可下载优化后的模型文件及其它相关文件

准备开发环境

Python 环境

请确认已安装 Python 3.9 或更高版本。

安装 AIMO Python SDK

可通过 pip 安装此 SDK:

bash
pip install aplux_aimo -i https://mirrors.aidlux.com

💡注意

请确保 pip 是最新版本,以避免潜在的安装问题。

获取 AIMO API Key

💡注意

请确认已注册阿加犀开发者帐号,并能成功登录 AIMO

登录 AIMO 后,点击右上角用户图标弹出下拉菜单,然后点击 “用户密钥”。在弹出的对话框中,将显示 API Key 信息。

快速上手

以下是一个基本的使用 AIMO Python SDK 完成模型优化任务的示例:

python
from aplux_aimo import AimoApi
from aplux_aimo.enums import SourceModelType, TargetDevice, ModelRuntime, TaskStatus, DownloadFileMode
from aplux_aimo.base_data import QuantizeOptions
from aplux_aimo.enums import ModelDataPrecision, QuantizeMode, CalibrationDataMode, CalibrationDatasetType

# 1. 初始化 AimoApi 
# 默认指向公有云服务 "https://aimo.aidlux.com/api/"
aimo = AimoApi() 

# 2. 通过API Key 登录 (请替换为你的 API Key)
# 为了安全,建议从环境变量或配置文件读取 API Key
api_key = "YOUR_API_KEY"
try:
    aimo.login(api_key=api_key)
    print("登录成功!")
except Exception as e:
    print(f"登录失败: {e}")
    exit()

# 3. 创建新任务
try:
    print("正在创建任务...")
    task = aimo.new_task(
        # 指定源模型框架及模型文件路径
        source_model_type=SourceModelType.ONNX,
        source_model_file=r"C:\\path\\to\\your\\model.onnx",  
        # 设置要部署模型的设备及推理框架
        target_device = TargetDevice.Qualcomm_QCS6490,
        target_runtime = ModelRuntime.QNN_2_28,
        description="我的第一个 AIMO 任务"
    )
    print(f"任务创建成功,任务 ID: {task.task_id}")

    # 4. (可选) 配置任务参数,例如量化
    # task.config.quantize_options = QuantizeOptions(
    #     quantize_precision=ModelDataPrecision.INT8,
    #     quantize_mode=QuantizeMode.Enhanced_CLE,
    #     enable_per_channel_quantize=True,
    #     calibration_data_mode=CalibrationDataMode.Image,
    #     calibration_dataset_type=CalibrationDatasetType.Custom,
    #     calibration_image_color_mode="rgb",
    #     calibration_image_mean=[123.675, 116.28, 103.53],
    #     calibration_image_std=[58.395, 57.12, 57.375]
    # )
    # print("量化参数已配置。")

    # 5. (可选) 如果配置了量化且使用自定义校准集,上传校准数据
    # if task.config.quantize_options and \
    #    task.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom and \
    #    task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Image:
    #     calibration_files = [
    #         r"C:\\path\\to\\your\\image1.png", # 请替换为您的校准图片路径
    #         r"C:\\path\\to\\your\\image2.jpg"  # 请替换为您的校准图片路径
    #     ]
    #     task.upload_calibration_data_files(files=calibration_files)
    #     print("校准数据上传成功。")

    # 6. 提交任务
    print("正在提交任务...")
    task.submit()
    print("任务提交成功!")

    # 7. 轮询任务状态
    print("正在轮询任务状态...")
    final_status = task.poll_status(interval=10, timeout=3600) # 每10秒查询一次,最长等待1小时
    print(f"任务最终状态: {final_status.value}")

    # 8. 获取任务结果
    if final_status == TaskStatus.SUCCESS:
        print("任务成功,正在获取结果...")
        result_info = task.get_result()
        print(f"任务结果消息: {result_info.message}")
        # print(f"任务配置详情: {result_info.task_config.to_custom_dict()}") # 查看详细配置
        
        # 9. 下载结果文件 (例如,下载优化后的模型)
        print("正在下载输出模型...")
        download_path = r"C:\\path\\to\\save\\downloads" # 请替换为您的下载保存路径
        import os
        if not os.path.exists(download_path):
            os.makedirs(download_path)
        
        saved_file = task.download(
            file_mode=DownloadFileMode.OutputModel, 
            output_file_path=download_path
        )
        print(f"输出模型已下载到: {saved_file}")
    else:
        print(f"任务失败或未成功完成,状态: {final_status.value}")
        # 可以获取日志信息排查问题
        # error_log = task.get_info_log()
        # print(f"任务日志: {error_log}")


except Exception as e:
    print(f"操作过程中发生错误: {e}")
    # 如果 task 对象已创建,可以尝试获取日志
    if 'task' in locals() and task:
        try:
            print(f"尝试获取任务日志: {task.get_info_log()}")
        except Exception as log_e:
            print(f"获取日志失败: {log_e}")

SDK API 参考

AimoApi 类 (aplux_aimo.client.AimoApi)

AimoApi 是与 AIMO 服务进行交互的入口点,负责用户认证和任务的创建与获取。

初始化 AimoApi 客户端

  • 参数:
    • base_url (str, 可选): AIMO 服务的 API 地址。默认为 "https://aimo.aidlux.com/api/"

login(api_key: str)

使用 API Key 登录 AIMO 服务。

  • 参数:
    • api_key (str): 您的 AIMO API Key。
  • 异常:
    • NotApiKeyError: 如果 api_key 为空。
    • NetworkError: 如果网络连接失败。
    • APIRequestError: 如果 API 请求返回错误状态码。

new_task(source_model_type: SourceModelType | str, target_device: TargetDevice | str, target_runtime: ModelRuntime | str, source_model_file: Union[List[str], str] = None, description: str = "", **kwargs) -> AimoTask

创建一个新的优化任务。需要参考 SourceModelType 支持详情 和 TargetDevice 支持详情。

  • 参数:
    • source_model_type (SourceModelType | str): 源模型类型,例如 SourceModelType.ONNX 或字符串 "onnx"
    • target_device (TargetDevice | str): 目标硬件设备,例如 TargetDevice.Qualcomm_QCS6490 或字符串 "sm7325"
    • target_runtime (ModelRuntime | str): 目标模型运行时/格式,例如 ModelRuntime.QNN_2_28 或字符串 "qnn_2.28"
    • source_model_file (Union[List[str], str], 可选): 源模型文件的本地路径或路径列表。如果提供了此参数,SDK 会自动上传模型文件,如果不提供此参数则提供 source_model_urls 参数。
    • description (str, 可选): 任务的描述信息。
    • **kwargs: 其他可选的任务配置参数,会透传给 TaskConfig。例如 quantize_options=QuantizeOptions(...)
  • 返回: AimoTask 实例。
  • 异常:
    • ParameterError: 如果必要参数缺失或无效。
    • FileNotExistError: 如果 source_model_file 路径无效。
    • APIRequestError: 如果 API 请求失败。

task(task_id: str) -> AimoTask

获取一个已存在的任务实例。

  • 参数:
    • task_id (str): 已存在的任务 ID。
  • 返回: AimoTask 实例。
  • 异常:
    • TaskNotExistError: 如果任务 ID 不存在 (在后续调用任务方法时可能抛出)。

AimoTask 类 (aplux_aimo.client.AimoTask)

AimoTask 类用于管理单个优化任务,包括配置、上传文件、提交、监控和获取结果等操作。

  • 参数:
    • api_key (str): AIMO API Key。
    • task_id (str): 任务的唯一标识符。
    • base_url (str, 可选): AIMO 服务的 API 地址。

属性

  • task_id: str: 任务的唯一 ID。
  • api_key: str: 用于认证的 API Key。
  • base_url: str: AIMO 服务的基础 URL。
  • task_config: Optional[TaskConfig]: 任务的配置信息。可以通过 task.config.attribute_name = value 的方式直接修改,或通过 set_config() 方法设置。
  • result_info: Optional[ResultTaskInfo]: 任务完成后的结果信息。在调用 get_result() 后填充。
  • task_status: TaskStatus: 当前任务的执行状态。
  • accuracy_eval_task_status: AccuracyEvalTaskStatus: 当前精度分析任务的状态。

upload_model(source_model_type: SourceModelType | str, source_model_file: Union[List[str], str], description: str = "", timeout=3600)

上传模型文件到 AIMO。通常在 AimoApi.new_task() 中如果提供了 source_model_file 会自动调用。如果创建任务时未上传模型,可以手动调用此方法。

  • 参数:
    • source_model_type (SourceModelType | str): 源模型类型。
    • source_model_file (Union[List[str], str]): 源模型文件的本地路径或路径列表。
    • description (str, 可选): 任务描述。
    • timeout (int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
  • 返回: API 响应数据 (dict)。
  • 异常:
    • FileNotFoundError: 如果模型文件未找到。
    • APIRequestError: 如果 API 请求失败。
    • 此方法会更新 task.config.source_model_urlstask.config.input_output_params

upload_calibration_data_files(files: list[str] | str, timeout=3600)

为量化任务上传校准数据集文件(通常是图片)。 仅当 task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Imagetask.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom 时有效。

  • 参数:
    • files (list[str] | str): 校准数据集的文件路径列表或单个文件路径/文件夹路径。
    • timeout (int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
  • 返回: API 响应数据 (dict)。
  • 异常:
    • FileNotFoundError: 如果校准文件未找到。
    • APIRequestError: 如果 API 请求失败。
    • 此方法会更新 task.config.quantize_options.calibration_data_files

upload_raw_calibration_data_files(raw_calibration_data: list[dict], timeout=3600)

为量化任务上传原始(raw)格式的校准数据。 仅当 task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Raw 时有效。

  • 参数:
    • raw_calibration_data (list[dict]): 一个列表,每个元素是一个字典,描述一个输入节点的校准数据。 示例:
      python
      [
          {
              "input_name": "input.1", # 对应模型的一个输入节点名
              "files": ["/path/to/input1_data1.raw", "/path/to/input1_data2.raw"] # 该输入节点的校准文件列表
          },
          {
              "input_name": "input.2",
              "files": ["/path/to/input2_data1.raw"]
          }
      ]
    • timeout (int, 可选): 上传超时时间(秒),默认为 3600 秒 (1小时)。
  • 异常:
    • ParameterError: 如果 input_node 数据未找到或格式错误。
    • FileNotFoundError: 如果校准文件未找到。
    • APIRequestError: 如果 API 请求失败。
    • 此方法会更新 task.config.input_output_params.input_nodes 中对应节点的 files 属性。

is_exist() -> bool

检查任务在服务端是否真实存在。

  • 返回: bool,如果任务存在则为 True,否则为 False
  • 异常:
    • APIRequestError: 如果 API 请求失败。

submit() -> bool

提交任务到 AIMO 服务执行优化/转换。 在提交前,请确保所有必要的配置(如 target_device, target_runtime, source_model_urls or source_model_file)已设置,并且如果需要量化,相关的 quantize_options 和校准数据已准备好。

  • 返回: bool,如果提交请求成功则为 True
  • 异常:
    • ParameterError: 如果任务配置不完整(例如 TensorFlow PB 模型缺少输出节点名)。
    • AimoTaskError: 如果任务提交失败(例如任务状态不允许提交)。
    • APIRequestError: 如果 API 请求失败。

poll_status(interval: int = 5, timeout: int = 3600) -> TaskStatus

轮询任务的执行状态,直到任务完成(成功或失败)或超时。

  • 参数:
    • interval (int, 可选): 轮询间隔时间(秒),默认为 5 秒。
    • timeout (int, 可选): 最大等待时间(秒),默认为 3600 秒 (1小时)。如果超时,将返回当前最后一次获取到的状态。
  • 返回: TaskStatus 枚举成员,表示任务的最终状态。
  • 异常:
    • APIRequestError: 如果 API 请求在轮询过程中失败。

get_result() -> ResultTaskInfo

获取已完成任务的详细结果信息。

  • 返回: ResultTaskInfo 实例,包含了任务配置、状态、消息、耗时等。
  • 异常:
    • AimoTaskError: 如果任务尚未完成或获取结果失败。
    • APIRequestError: 如果 API 请求失败。
    • 此方法会填充 task.result_infotask.task_status 属性。

download(file_mode: DownloadFileMode = DownloadFileMode.OutputModel, output_file_path: str = os.getcwd()) -> str

下载与任务相关的文件。

  • 参数:
    • file_mode (DownloadFileMode, 可选): 要下载的文件类型,默认为 DownloadFileMode.OutputModel (优化后的模型)。也可以是 DownloadFileMode.InputModel (原始输入模型)。
    • output_file_path (str, 可选): 下载文件保存的本地目录路径。默认为当前工作目录。
  • 返回: str,下载文件的完整本地路径。
  • 异常:
    • AimoTaskError: 如果无法获取下载链接或下载失败。
    • APIRequestError: 如果 API 请求失败。
    • IOError: 如果文件写入失败。

terminate() -> Dict[str, Any]

请求终止正在执行的任务。

  • 返回: API 响应数据 (dict)。
  • 异常:
    • AimoTaskError: 如果任务终止失败。
    • APIRequestError: 如果 API 请求失败。

delete() -> Dict[str, Any]

请求删除服务端上的任务。

  • 返回: API 响应数据 (dict)。
  • 异常:
    • AimoTaskError: 如果任务删除失败。
    • APIRequestError: 如果 API 请求失败。

submit_accuracy_eval() -> bool

提交精度分析任务。在优化任务成功完成后调用。

  • 返回: bool,如果提交请求成功则为 True
  • 异常:
    • AimoTaskError: 如果提交精度分析失败。
    • APIRequestError: 如果 API 请求失败。

poll_accuracy_eval_status(interval: int = 5, timeout: int = 3600) -> AccuracyEvalTaskStatus

轮询精度分析任务的状态。

  • 参数:
    • interval (int, 可选): 轮询间隔时间(秒),默认为 5 秒。
    • timeout (int, 可选): 最大等待时间(秒),默认为 3600 秒 (1小时)。
  • 返回: AccuracyEvalTaskStatus 枚举成员。
  • 异常:
    • APIRequestError: 如果 API 请求在轮询过程中失败。

get_accuracy_eval_result() -> AccuracyEvalResult

获取精度分析的结果。

  • 返回: AccuracyEvalResult 实例。
  • 异常:
    • AimoTaskError: 如果获取精度分析结果失败。
    • APIRequestError: 如果 API 请求失败。
    • 此方法会更新 task.accuracy_eval_task_status

get_info_log() -> str

获取优化任务的日志信息。

  • 返回: str,包含任务日志的字符串。
  • 异常:
    • AimoTaskError: 如果获取日志失败。
    • APIRequestError: 如果 API 请求失败。

数据类 (aplux_aimo.base_data)

SDK 使用 dataclass 来组织和传递数据。

QuantizeOptions

量化相关选项:

  • quantize_precision: ModelDataPrecision | str: 量化精度,例如 ModelDataPrecision.INT8 ("A8_W8"), ModelDataPrecision.INT16 ("A16_W8")。
  • quantize_mode: QuantizeMode | str: 量化模式,例如 QuantizeMode.Enhanced_CLE ("cle"), QuantizeMode.Enhanced_ADA ("ada")。
  • enable_per_channel_quantize: bool: 是否启用 Per-Channel 量化,默认为 False
  • calibration_data_mode: CalibrationDataMode | str: 校准数据模式,例如 CalibrationDataMode.Image ("cv"), CalibrationDataMode.Raw ("nlp"), CalibrationDataMode.Random ("random")。
  • calibration_dataset_type: CalibrationDatasetType | str: 校准数据集类型,例如 CalibrationDatasetType.ImageNet ("imagenet"), CalibrationDatasetType.COCO ("coco"), CalibrationDatasetType.Face ("face"), CalibrationDatasetType.Custom ("")。
  • calibration_data_files: list[str]: 自定义校准数据集的文件 URL 列表 (由 SDK 内部上传后填充)。
  • calibration_image_color_mode: str: 校准图像颜色模式 (当 calibration_data_mode 为 Image 时),例如 "rgb", "bgr"。默认为 "rgb"。
  • calibration_image_mean: list[float]: 校准图像均值列表,例如 [123.675, 116.28, 103.53]
  • calibration_image_std: list[float]: 校准图像标准差列表,例如 [58.395, 57.12, 57.375]

InputNode

模型输入节点信息。

  • node_name: str: 输入节点名称。
  • dimension: list[int]: 输入节点形状/维度,例如 [1, 3, 224, 224]
  • files: list[str]: 当 CalibrationDataModeRaw 时,该输入节点对应的校准数据文件 URL 列表,调用 upload_raw_calibration_data_files 方法后内部会自动上传校准数据文件到服务端,不需要用户手动上传。

InputOutputParams

模型的输入输出参数。

  • input_nodes: list[InputNode] | list[dict]: 输入节点列表,每个元素是 InputNode 实例或其字典表示。
  • output_node_name: list[str]: 输出节点名称列表。

TaskConfig

任务的完整配置。

  • task_id: str: 任务 ID。
  • description: str: 任务描述。
  • target_device: TargetDevice | str: 目标硬件设备。
  • target_device_name: str: 目标硬件设备名称 (通常由 SDK 根据 target_device 填充)。
  • target_runtime: ModelRuntime | str: 目标模型运行时/格式。
  • input_output_params: Optional[InputOutputParams]: 模型的输入输出参数。
  • quantize_options: Optional[QuantizeOptions]: 量化配置选项。
  • source_model_type: SourceModelType | str: 源模型类型。
  • source_model_file: Optional[list[str]]: (本地) 源模型文件路径列表 (主要用于创建任务时 SDK 内部记录,实际提交给服务端的是 source_model_urls)。
  • source_model_urls: list[str]: 源模型文件在服务端的 URL 列表 (由模型上传后填充)。

AccuracyEvalResult

精度分析结果:

  • id: str: 精度分析任务的 ID。
  • status: AccuracyEvalTaskStatus: 精度分析任务的状态。
  • result_json: list: 精度分析结果的 JSON 数据。
  • message: str: 精度分析相关消息。

ResultTaskInfo

任务执行完成后的详细结果信息:

  • task_id: str: 任务 ID。
  • task_config: Optional[TaskConfig]: 完成时的任务配置快照。
  • message: str: 任务执行结果的消息。
  • create_time: str: 任务创建时间戳。
  • usage_time: str: 任务执行耗时。
  • accuracy_eval_result: Optional[AccuracyEvalResult]: 精度分析结果 (如果执行了)。
  • task_status: TaskStatus | str: 任务的最终状态。
  • input_model: str: 输入模型的 URL (通常是 zip 格式)。

UpdateFileInfo

文件上传操作后的返回信息:

  • file_datas: list[dict]: 上传成功的每个文件的信息列表,包含 url 等。
  • model_info: dict: 如果上传的是模型文件,这里会包含从模型解析出的信息 (例如输入输出节点)。

枚举类型 (aplux_aimo.enums)

SDK 定义了多个枚举类型以规范参数。

SourceModelType

源模型类型:

  • ONNX = "onnx"
  • PyTorch = "pt"
  • TensorFlow_PB = "pb"
  • TensorFlow_Lite = "tflite"
  • TensorFlow_Save_Model = "pbsm"
  • Caffe = "caffe"
  • PaddlePaddle = "pd"

ModelRuntime

目标模型运行时/格式。

  • ONNX = "onnx"
  • TFLite = "tflite"
  • PaddleLite = "nb"
  • TNN = "tnn"
  • MNN = "mnn"
  • NCNN = "ncnn"
  • MindSpore = "ms"
  • RKNN_2 = "rknn"
  • SNPE_1 = "dlc_1.x"
  • SNPE_2_16 = "dlc_2.16"
  • SNPE_2_23 = "dlc_2.23"
  • SNPE_2_29 = "dlc_2.29"
  • SNPE_2_31 = "dlc_2.31"
  • QNN_2_16 = "qnn_2.16"
  • QNN_2_23 = "qnn_2.23"
  • QNN_2_29 = "qnn_2.29"
  • QNN_2_31 = "qnn_2.31"

TargetDevice

目标硬件设备。每个枚举成员包含一个 value (用于 API) 和一个 label (描述性名称)。

  • Qualcomm_QCS6490 = ("sm7325", "QCS6490")
  • Qualcomm_QCS8250 = ("8250", "QCS8250")
  • Qualcomm_QCS8550 = ("qcs8550", "QCS8550")
  • Qualcomm_Snapdragon_8_Gen1 = ("sm8450", "Snapdragon 8 Gen1")
  • Qualcomm_Snapdragon_8_Gen2 = ("sm8550", "Snapdragon 8 Gen2")
  • Qualcomm_Snapdragon_8_Gen3 = ("sm8650", "Snapdragon 8 Gen3")
  • Qualcomm_Snapdragon_8_Elite_Gen1 = ("sm8750", "Snapdragon 8 Elite")
  • Rockchip_RK3588 = ("3588", "RK3588")

使用时,可以直接用枚举成员,SDK 会自动取其 value,例如 TargetDevice.Qualcomm_QCS6490

QuantizeMode

量化模式:

  • Enhanced_ADA = "ada"
  • Enhanced_CLE = "cle"

ModelDataPrecision

模型量化数据精度:

  • INT8 = "A8_W8"
  • INT16 = "A16_W8"

CalibrationDataMode

校准数据模式 (对应 API 中的 task_type):

  • Image = "cv" (图像数据)
  • Raw = "nlp" (原始二进制数据,常用于 NLP 模型)
  • Random = "random" (使用随机生成的数据)

CalibrationDatasetType

预设的校准数据集类型或自定义。

  • ImageNet = "imagenet"
  • COCO = "coco"
  • Face = "face"
  • Custom = "" (表示使用用户自定义上传的数据集)

DownloadFileMode

下载文件类型。

  • InputModel = "input_model" (下载原始输入模型)
  • OutputModel = "output_model" (下载优化/转换后的输出模型)

TaskStatus

任务执行状态。

  • PENDING = "PENDING" (排队等待)
  • STARTED = "START" (已开始)
  • SUCCESS = "SUCCESS" (成功)
  • FAILURE = "FAILURE" (失败)
  • NOTSUBMIT = "NOTSUBMIT" (未提交)

AccuracyEvalTaskStatus

精度分析任务状态。

  • PENDING = "Waiting" (等待中)
  • STARTED = "Running" (运行中)
  • SUCCESS = "Success" (成功)
  • FAILURE = "Failure" (失败)
  • NOTSUBMIT = "Notsubmit" (未提交)

高级功能与用例

操作已有任务

如果您知道一个任务的 ID,可以使用 api.task(task_id) 来获取该任务的 AimoTask 实例,然后对其进行操作,例如查询状态、获取结果或下载文件。

python
# ... (api 初始化和登录同上) ...
known_task_id = "your_existing_task_id" # 替换为已存在的任务ID
try:
    existing_task = api.task(task_id=known_task_id)
    print(f"获取任务 {existing_task.task_id} 成功。")
    
    # 检查任务是否存在 (可选,因为 task() 本身不直接校验)
    if not existing_task.is_exist():
        print(f"任务 {existing_task.task_id} 在服务端不存在。")
    else:
        print(f"当前状态: {existing_task.poll_status(timeout=60).value}") # 短暂轮询获取最新状态
        if existing_task.task_status == TaskStatus.SUCCESS:
            result = existing_task.get_result()
            print(f"任务结果: {result.message}")
            # existing_task.download(...)
        elif existing_task.task_status == TaskStatus.FAILURE:
            print(f"任务失败,日志: {existing_task.get_info_log()}")

except Exception as e:
    print(f"操作已有任务时出错: {e}")

模型量化

模型量化通过 TaskConfig 中的 quantize_options 属性进行配置。

详细配置 QuantizeOptions

python
from aplux_aimo.base_data import QuantizeOptions
from aplux_aimo.enums import ModelDataPrecision, QuantizeMode, CalibrationDataMode, CalibrationDatasetType

# ... (task 创建同上) ...

task.config.quantize_options = QuantizeOptions(
    quantize_precision=ModelDataPrecision.INT8,       # 或 ModelDataPrecision.INT16
    quantize_mode=QuantizeMode.Enhanced_CLE,        # 或 QuantizeMode.Enhanced_ADA
    enable_per_channel_quantize=True,               # 是否启用 per-channel 量化
    calibration_data_mode=CalibrationDataMode.Image, # 校准数据类型:Image, Raw, Random
    calibration_dataset_type=CalibrationDatasetType.Custom, # 校准数据集:ImageNet, COCO, Face, Custom
    calibration_image_color_mode="rgb",             # 如果是 Image, 指定颜色模式
    calibration_image_mean=[123.675, 116.28, 103.53], # 图像均值
    calibration_image_std=[58.395, 57.12, 57.375]    # 图像标准差
)
print("量化参数已配置。")

上传图像校准数据集

如果 calibration_dataset_type 设置为 CalibrationDatasetType.Customcalibration_data_modeCalibrationDataMode.Image,则需要上传校准图片。

python
# ... (假设 task 和 quantize_options 已配置) ...
if task.config.quantize_options.calibration_dataset_type == CalibrationDatasetType.Custom and \
   task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Image:
    
    calibration_image_paths = [
        r"/path/to/your/calibration_image1.jpg",
        r"/path/to/your/calibration_image2.png",
        # ... 更多图片路径
    ]
    # 也可以是一个包含图片的文件夹路径
    # calibration_image_paths = r"/path/to/your/calibration_images_folder/"
    try:
        task.upload_calibration_data_files(files=calibration_image_paths)
        print("图像校准数据集上传成功。")
    except Exception as e:
        print(f"上传图像校准数据失败: {e}")

# 之后再 task.submit()

上传原始校准数据集

如果 calibration_data_modeCalibrationDataMode.Raw,则需要上传原始格式的校准数据。

python
# ... (假设 task 和 quantize_options 已配置) ...
if task.config.quantize_options.calibration_data_mode == CalibrationDataMode.Raw:
    # 确保 task.config.input_output_params.input_nodes 已经通过模型上传被填充
    # 或者在创建任务时通过 **kwargs 传递了 input_output_params
    
    # 示例:假设模型有两个输入节点 "input_node_A" 和 "input_node_B"
    # 并且这些节点名存在于 task.config.input_output_params.input_nodes[i].node_name
    raw_data_config = [
        {
            "input_name": "input_node_A", # 替换为实际的输入节点名
            "files": [
                r"/path/to/input_A_sample1.raw",
                r"/path/to/input_A_sample2.raw"
            ]
        },
        {
            "input_name": "input_node_B", # 替换为实际的输入节点名
            "files": [r"/path/to/input_B_sample1.raw"]
        }
    ]
    try:
        task.upload_raw_calibration_data_files(raw_calibration_data=raw_data_config)
        print("原始校准数据集上传成功。")
    except Exception as e:
        print(f"上传原始校准数据失败: {e}")
        
# 之后再 task.submit()

模型精准度分析

在优化任务成功完成后,您可以为该任务提交一个模型精准度分析子任务,了解模型每一层的精准度情况。

python
# ... (假设 task 是一个已成功完成的 AimoTask 实例) ...
if task.task_status == TaskStatus.SUCCESS:
    try:
        print("正在提交精度分析任务...")
        task.submit_accuracy_eval()
        print("精度分析任务已提交。")

        print("正在轮询精度分析状态...")
        accuracy_status = task.poll_accuracy_eval_status(interval=5, timeout=1800) # 等待最多30分钟
        print(f"精度分析最终状态: {accuracy_status.value}")

        if accuracy_status == AccuracyEvalTaskStatus.SUCCESS:
            accuracy_result = task.get_accuracy_eval_result()
            print("获取精度分析结果成功:")
            print(f"  ID: {accuracy_result.id}")
            print(f"  Message: {accuracy_result.message}")
            print(f"  Result JSON: {accuracy_result.result_json}") # 详细结果
        else:
            print(f"精度分析失败或未成功完成。状态: {accuracy_status.value}")

    except Exception as e:
        print(f"精度分析过程中出错: {e}")

获取任务日志

使用 task.get_info_log() 可以获取任务执行过程中的日志,有助于排查问题。

python
# ... (假设 task 是一个已提交或已完成的 AimoTask 实例) ...
try:
    logs = task.get_info_log()
    print("任务日志:")
    print(logs)
except Exception as e:
    print(f"获取任务日志失败: {e}")

删除任务

可以使用 task.delete() 从服务端删除一个任务。

python
# ... (假设 task 是一个 AimoTask 实例) ...
try:
    response = task.delete()
    print(f"任务 {task.task_id} 删除请求已发送。响应: {response}")
except Exception as e:
    print(f"删除任务失败: {e}")

错误处理与异常 (aplux_aimo.exceptions)

SDK 定义了一系列自定义异常来帮助开发者处理特定错误情况。所有自定义异常都继承自 AimoError

  • AimoError: 所有 AIMO SDK 异常的基类。
  • NotApiKeyError: API Key 未设置或无效时抛出。
    • message: "API Key is not set, please call login method first"
  • TaskNotExistError: 尝试操作一个在服务端不存在的任务时(或任务ID无效)。
    • message: "Task does not exist"
  • ParameterError: API 调用时参数错误或缺失。
    • message: "Parameter error" (或更具体的错误信息)
  • NetworkError:发生网络连接问题。
    • message: "Network Error"
  • APIRequestError: API 服务端返回错误状态码或非预期响应。
    • message: "API Request Error" (或更具体的错误信息)
  • FileNotExistError: 指定的本地文件路径不存在。
    • message: "File does not exist"
  • AimoTaskError: 通用的任务相关操作错误(例如状态不符、操作失败等)。
    • message: "Task Error" (或更具体的错误信息)

建议的错误处理示例:

python
from aplux_aimo import AimoApi
from aplux_aimo.exceptions import NotApiKeyError, ParameterError, APIRequestError, AimoTaskError, FileNotExistError

api = AimoApi()
try:
    api.login(api_key="YOUR_API_KEY") # 请替换
    
    task = api.new_task(
        source_model_type="onnx",
        source_model_file=r"C:\\path\\to\\your\\model.onnx", # 请替换
        target_device="sm7325",
        target_runtime="qnn_2.28"
    )
    task.submit()
    task.poll_status()
    result = task.get_result()
    print(f"任务成功: {result.message}")

except NotApiKeyError as e:
    print(f"认证失败: {e}")
except FileNotExistError as e:
    print(f"文件错误: {e}")
except ParameterError as e:
    print(f"参数配置错误: {e}")
except APIRequestError as e:
    print(f"API 请求失败: {e}")
    # 可以检查 e.args 或 e.message 获取更多服务端返回信息
except AimoTaskError as e:
    print(f"任务处理失败: {e}")
    if 'task' in locals() and task: # 如果 task 对象存在
        try:
            print(f"尝试获取任务日志: {task.get_info_log()}")
        except Exception as log_e:
            print(f"获取日志也失败了: {log_e}")
except Exception as e: # 其他未知错误
    print(f"发生未知错误: {e}")

常见问题 (FAQ)

  • Q1: 如何获取 API Key?A1: 您需要在 AIMO 官方平台 (例如:https://aimo.aidlux.com/) 注册账户,并在右上角点击用户头像然后点击用户密钥获取 API Key。

  • Q2: 支持哪些模型格式和硬件设备?A2: 请参考本文档中 SourceModelType (源模型类型) 和 TargetDevice (目标硬件设备) 小节的详细列表。

  • Q3: 量化失败如何排查?A3:

    • 检查 QuantizeOptions 配置是否正确,特别是校准模式、数据集类型、精度选择。
    • 如果使用自定义校准集,确保校准数据质量良好且格式正确,文件已成功上传。
    • 查看 task.get_info_log() 获取详细的转换日志,其中可能包含量化过程中的错误信息。
    • 确认模型本身是否适合进行所选精度的量化。
  • Q4: 为什么当SourceModelType模型为TensorFlow(saved model)时,aimo页面上传的文件时文件夹,而sdk上传的文件是zip文件?A4: 这是因为TensorFlow(saved model)模型在AIMO平台上的上传方式与SDK上传方式不同,aimo页面上传后同样是处理成zip包。

附录

SourceModelType 支持详情

下表列出了不同的源模型框架类型 (SourceModelType) ,所需文件数量,文件类型,可转换的目标推理运行时 (ModelRuntime):

源模型框架类型 (SourceModelType)值 (value)文件数量 (file_len)支持文件类型 (file_type)支持的目标运行时 (ModelRuntime)
ONNXonnx1["onnx"]TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
PyTorchpt1["pt"] 注意: 模型文件需包含完整的模型结构和参数ONNX(onnx), TFLite(tflite), SNPE(dlc_*), QNN 2.23(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
TensorFlow(frozen pb)pb1["pb"]ONNX(onnx), TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
TensorFlow(saved model)pbsm1["zip"]ONNX(onnx), TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
TensorFlow Litetflite1["tflite"]TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
Caffecaffe2[".prototxt", ".caffemodel"]ONNX(onnx), TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
Paddle Paddlepd2[".pdmodel", ".pdiparams"]ONNX(onnx), TFLite(tflite), SNPE(dlc_*), QNN(qnn_*), RKNN_2(rknn), Paddle Lite(nb), TNN(tnn), MNN(mnn), NCNN(ncnn), MindSpore(ms)
MXNetMXNet2[".json", ".params"]ONNX(onnx), TFLite(tflite), Paddle Lite (nb), TNN (tnn), MNN (mnn), NCNN (ncnn), MindSpore (ms)

TargetDevice 支持详情

下表列出了不同的目标硬件设备 (TargetDevice) 及其支持的目标运行时 (ModelRuntime) 和支持的量化数据精度选项 (ModelDataPrecision)

目标设备 (label)值 (value)支持的目标运行时 (ModelRuntime)支持的量化数据精度 (ModelDataPrecision)
Snapdragon 888sm8350SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
Snapdragon 8 Gen1sm8450SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
Snapdragon 8 Gen2sm8550SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
Snapdragon 8 Gen3sm8650SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
Snapdragon 8 Elitesm8750SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
QCS6490sm7325SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
QCS8250sm8250SNPE_1_x(dlc_1.x), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8 (W8A8)
QCS8550qcs8550SNPE_2_x(dlc_2.*), QNN_2_x(qnn_2.*), TFLite(tflite), MNN(mnn), NCNN(ncnn), MindSpore(ms), PaddleLite(nb), TNN(tnn), ONNX(onnx)INT8(W8A8), INT16(W8A16)
RK3588rk3588RKNN_2(rknn), TFLite(tflite), ONNX(onnx)INT8 (W8A8)
通用 Arm CPU & GPU""TFLite(tflite), MNN(mnn),NCNN (ncnn), MindSpore(ms), Paddle Lite(nb), TNN(tnn), ONNX(onnx)