ONNX export of a Random Forest
Download Python samples: a ZIP archive containing all samples is available here: Samples of ONNX export.
Scikit-learn: Random Forest Classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# Train a random forest classifier on a reproducible synthetic dataset
# (3 informative features, no redundant ones) and export it as ONNX.
X, y = make_classification(n_samples=1000, n_features=3, n_informative=3, n_redundant=0, random_state=1)
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X, y)

# One float32 input tensor: dynamic batch dimension, one column per feature.
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]

# ZipMap must always be disabled: the ZipMap operator is not implemented in
# TF3800, so the class probabilities are exported as a plain tensor instead.
onnx_model = convert_sklearn(clf, initial_types=initial_type, options={type(clf): {'zipmap': False}})

# Serialize the ONNX protobuf to disk; the context manager closes the file.
with open("rf_classifier.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
Scikit-learn: Random Forest Regressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# Train a random forest regressor on a reproducible synthetic dataset
# (5 informative features) and export it as ONNX.
X, y = make_regression(n_samples=1000, n_features=5, n_informative=5, random_state=2)
reg = RandomForestRegressor(n_estimators=100)
reg.fit(X, y)

# One float32 input tensor: dynamic batch dimension, one column per feature.
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
onnx_model = convert_sklearn(reg, initial_types=initial_type)

# Serialize the ONNX protobuf to disk; the context manager closes the file.
with open("rf_regressor.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
LightGBM: Random Forest Regressor
import numpy as np
import onnx_graphsurgeon as gs
import lightgbm as lgb
from lightgbm import LGBMRegressor
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from onnxmltools.convert import convert_lightgbm
import onnxmltools.convert.common.data_types
import onnx
# # Generate data for regression (10 informative features, single target)
X, y = make_regression(n_samples=300, n_features=10, n_informative=10, n_targets=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

# # Construct LightGBM-RandomForest-Model
# boosting_type='rf' runs LightGBM in random-forest mode, which relies on
# row/column subsampling (subsample, subsample_freq, colsample_bynode below).
# 'silent' was removed from LGBMRegressor in LightGBM 4.0; verbose=-1 is the
# supported way to suppress LightGBM's own log output.
model = LGBMRegressor(
    boosting_type='rf', class_weight=None, colsample_bynode=0.3, colsample_bytree=1.0,
    importance_type='split', learning_rate=0.05, max_depth=-1, min_child_samples=2,
    min_child_weight=0.001, min_split_gain=0.0, n_estimators=150, n_jobs=-1, num_class=1,
    num_leaves=500, objective='regression', random_state=None, reg_alpha=0.0,
    reg_lambda=0.0, verbose=-1, subsample=0.632, subsample_for_bin=200000, subsample_freq=1,
)
# fit(verbose=...) was removed in LightGBM 4.0; the log_evaluation callback
# prints the RMSE on both eval sets every 20 iterations instead.
model.fit(
    X_train, y_train,
    eval_set=[(X_test, y_test), (X_train, y_train)],
    eval_metric='rmse',
    callbacks=[lgb.log_evaluation(period=20)],
)

# # Convert model to ONNX
onnxfile = 'lgbm-regressor-randomforest.onnx'
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
onnx_model = convert_lightgbm(model, initial_types=initial_type, target_opset=12)

# Manipulate ONNX graph
# # Import model to graph object
graph = gs.import_onnx(onnx_model)
graph.name = "LGBM-RandomForest"

# # Modify TreeEnsemble output shape (necessary to meet TwinCAT requirement, working on an update to make this step obsolete)
tree_node = [node for node in graph.nodes if node.op == "TreeEnsembleRegressor"][0]
tree_node.outputs[0].shape = [None, 1]
tree_node.outputs[0].dtype = np.float32

# # Modify DIV Node inputs to provide correct averaging (necessary to correct a bug in onnxmltools version 1.11.1)
# The divisor is replaced by the number of trees, so the summed tree outputs
# become their mean.
div_node = [node for node in graph.nodes if node.op == "Div"][0]
div_node.inputs[1].to_constant(values=np.asarray([[model.n_estimators]], dtype=np.float32))

# # Export graph object to ONNX ProtoModel
graph.cleanup().toposort()
onnx_model = gs.export_onnx(graph)

# # Add ONNX domain tag to TreeEnsemble Node for proper node recognition (only a reset of the tag as it gets lost during onnx manipulation)
tree_node = [node for node in onnx_model.graph.node if node.op_type == "TreeEnsembleRegressor"][0]
tree_node.domain = "ai.onnx.ml"
tree_node.doc_string = "Converted from LGBMRegressor() model with explicit shaping"

# # Validate and export ONNX model to file (the context manager closes the file)
onnx.checker.check_model(onnx_model)
with open(onnxfile, "wb") as f:
    f.write(onnx_model.SerializeToString())