ONNX export of Hist Gradient Boosting
Download Python samples
A ZIP archive containing all samples can be found here: Samples of ONNX export.
HistGradientBoostingRegressor with scikit-learn
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.datasets import make_regression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# import onnx  # uncomment together with the checker call below to validate the model

# Generate a synthetic regression dataset: 300 samples, 10 features (all
# informative), one target.
X, y = make_regression(n_samples=300, n_features=10, n_informative=10, n_targets=1)

# Construct and train the model.
model = HistGradientBoostingRegressor(learning_rate=0.08, max_depth=3, max_iter=300)
model.fit(X, y)

# Convert the model to ONNX. The single input is declared as a float tensor
# whose first (batch) dimension is left unspecified (None) and whose second
# dimension equals the number of training features.
onnxfile = 'histgdb-regressor.onnx'
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
onnx_model = convert_sklearn(model, initial_types=initial_type, target_opset=12)

# Export to an ONNX file. The `with` statement closes the file when the block
# ends, so no explicit close() is needed.
# onnx.checker.check_model(onnx_model)
with open(onnxfile, "wb") as f:
    f.write(onnx_model.SerializeToString())
HistGradientBoostingClassifier with scikit-learn
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# import onnx  # uncomment together with the checker call below to validate the model

# Load the Iris dataset and hold out 10% of the samples as a test split.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

# Construct the model and train it on the training split only.
model = HistGradientBoostingClassifier(learning_rate=0.08, max_depth=3, max_iter=300)
model.fit(X_train, y_train)

# Convert the model to ONNX. The single input is declared as a float tensor
# whose first (batch) dimension is left unspecified (None) and whose second
# dimension equals the number of features.
onnxfile = 'histgdb-iris.onnx'
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
# ZipMap must always be disabled, because TF3800 does not implement it.
# The option is keyed by the estimator's class so it applies to this model.
onnx_model = convert_sklearn(
    model,
    initial_types=initial_type,
    options={type(model): {'zipmap': False}},
    target_opset=12,
)

# Export to an ONNX file. The `with` statement closes the file when the block
# ends, so no explicit close() is needed.
# onnx.checker.check_model(onnx_model)
with open(onnxfile, "wb") as f:
    f.write(onnx_model.SerializeToString())