32 lines
999 B
Python
32 lines
999 B
Python
import onnx
|
|
import tensorflow as tf
|
|
from onnx_tf.backend import prepare
|
|
|
|
|
|
def onnx_to_tflite(onnx_model_path, tflite_model_path):
|
|
# 加载 ONNX 模型
|
|
onnx_model = onnx.load(onnx_model_path)
|
|
|
|
# 使用 onnx-tf 将 ONNX 模型转换为 TensorFlow 模型
|
|
tf_rep = prepare(onnx_model)
|
|
|
|
# 保存 TensorFlow 模型
|
|
saved_model_dir = 'runs/detect/train/weights/best.pt'
|
|
tf_rep.export_graph(saved_model_dir)
|
|
|
|
# 加载 TensorFlow 模型并转换为 TFLite 模型
|
|
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
|
tflite_model = converter.convert()
|
|
|
|
# 保存 TFLite 模型
|
|
with open(tflite_model_path, 'wb') as f:
|
|
f.write(tflite_model)
|
|
print(f"模型已成功转换并保存为 {tflite_model_path}")
|
|
|
|
|
|
# 使用示例
|
|
onnx_model_path = 'model.onnx' # 替换为你的 ONNX 模型文件路径
|
|
tflite_model_path = 'model.tflite' # 替换为你希望保存的 TFLite 模型路径
|
|
|
|
onnx_to_tflite(onnx_model_path, tflite_model_path)
|