Study Apple Core AI Pipeline

项目概览

整个 Core AI 推理管线由三个项目组成:

项目 包名 职责 类比
coreai-optimization coreai-opt 模型压缩(量化/剪枝/调色板) 模型瘦身
coreai-torch coreai-torch PyTorch → Core AI IR 转换 格式翻译
coreai-core coreai IR 优化、编译、设备端执行 编译器+运行时

完整流程

┌──────────────────────────────────────────────┐
│ ① coreai-optimization (模型压缩,可选)        │
│                                              │
│  from coreai_opt.quantization import ...     │
│                                              │
│  Quantizer(model, config) → prepare()        │
│  → finalize() → 压缩后的 nn.Module           │
│                                              │
│  支持:                                       │
│  - Quantization: INT8/FP4/...                │
│  - Palettization: 代码本压缩                  │
│  - Pruning: 剪枝                             │
└────────────────┬─────────────────────────────┘
                 │ 压缩后的 nn.Module (或原始模型)
                 ▼
┌──────────────────────────────────────────────┐
│ ② coreai-torch (格式转换)                     │
│                                              │
│  from coreai_torch import TorchConverter     │
│                                              │
│  torch.export.export(model)                  │
│    → ExportedProgram (FX Graph + state)      │
│    → run_decompositions(get_decomp_table())  │
│                                              │
│  TorchConverter().add_exported_program(ep)   │
│    → to_coreai() → AIProgram (MLIR)          │
│                                              │
│  核心机制:                                    │
│  - 遍历 FX Graph,逐节点转换                  │
│  - ATen 算子 → Core AI 算子映射               │
│  - 复合算子用 @coreai.graph 包装              │
└────────────────┬─────────────────────────────┘
                 │ AIProgram (MLIR IR)
                 ▼
┌──────────────────────────────────────────────┐
│ ③ coreai-core (编译 + 执行)                   │
│                                              │
│  from coreai.authoring import AIProgram      │
│  from coreai.runtime import NDArray          │
│                                              │
│  program.optimize()     ← 图优化              │
│  program.save_asset()   ← 序列化 .aimodel     │
│  asset.executable()     ← 编译设备代码         │
│  load_function("main")  ← 加载函数            │
│  function(inputs)       ← 执行推理            │
│                                              │
│  运行后端: Metal GPU / Neural Engine / CPU    │
└──────────────────────────────────────────────┘

coreai-torch 详细流程

1. torch.export.export → ExportedProgram

ep = torch.export.export(model, args=(sample_tensor,))
ep = ep.run_decompositions(get_decomp_table())

ExportedProgram 包含:

2. TorchConverter 遍历 FX Graph

converter = TorchConverter().add_exported_program(ep)
program = converter.to_coreai()

3. 复合算子(@coreai.graph)

某些 PyTorch 算子没有直接对应的 Core AI 基础算子:

@coreai.graph(private=True, no_inline=True, composite_decl=...)
def avg_pool2d_composite(...) -> Value:
    padded = coreai.PadOp(...)
    pooled = coreai.SumPool2dOp(...)
    return coreai.broadcasting_divide(pooled, divisor)

4. optimize()

5. 产物:AIProgram(MLIR Module)

类似 FX Graph,但是编译器中间表示(MLIR 格式):

权重类型 存储方式
常规权重 dense<[[0.5, -0.3]]> 内联在 IR 文本
特殊格式(float8/sub-byte/量化) DenseResourceElementsAttr(不透明二进制 blob)

端到端代码示例

import torch
from coreai_opt.quantization import Quantizer, QuantizerConfig
from coreai_torch import TorchConverter, get_decomp_table
from coreai.runtime import NDArray, StorageKind

# === 1. 压缩 (可选) ===
model = MyModel().eval()
config = QuantizerConfig.presets.w8()  # INT8 权重量化
quantizer = Quantizer(model, config)
prepared = quantizer.prepare(example_inputs)
compressed_model = quantizer.finalize()

# === 2. 转换 ===
ep = torch.export.export(compressed_model, args=sample)
ep = ep.run_decompositions(get_decomp_table())
program = TorchConverter().add_exported_program(ep).to_coreai()

# === 3. 编译 + 执行 ===
program.optimize()
asset = program.save_asset(Path("/tmp/model.aimodel"))

async with asset.executable() as ai_model:
    function = ai_model.load_function("main")
    result = await function({
        "x": NDArray(data=input_numpy, backing=StorageKind.METAL),
    })
    output = result["output"].numpy()