scala导出模型为pmml文件时报错:
Caused by: java.lang.IllegalArgumentException: Expected org.apache.spark.ml.Transformer subclass, got org.apache.spark.ml.feature.OneHotEncoder
【报错位置】:
val pmml = new PMMLBuilder(df.schema, model).build()
(该异常在构造 PMMLBuilder 注册转换器时抛出。根本原因通常是 jpmml-sparkml 依赖版本与 Spark 运行时版本不匹配——Spark 3.0 起 OneHotEncoder 由 Transformer 改为 Estimator,旧版 jpmml-sparkml 仍按 Transformer 注册,故报 "Expected Transformer subclass, got OneHotEncoder"。需将 jpmml-sparkml 升级到与 Spark 版本对应的发行版;此外传入的 schema 必须是流水线拟合数据的 schema,而非原始全 String 的 df.schema。)
【全部代码】:
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder
import java.io.{File, FileOutputStream}
import javax.xml.transform.stream.StreamResult
object SparkMLLogicalRegressionPipeLine {

  /**
   * Trains a logistic-regression pipeline on the Iris CSV data and exports it as PMML.
   *
   * Arguments:
   *   args(0) murl      - Spark master URL, e.g. local[2]
   *   args(1) inputfile - path to the Iris CSV (first line is the header)
   *   args(2) modelpath - output path for the .pmml file
   *
   * Fixes versus the original code:
   *   1. The VectorAssembler is now a pipeline STAGE instead of being applied up-front,
   *      and PMMLBuilder receives the schema of the DataFrame the pipeline was fit on
   *      (dfDouble.schema, all-Double scalar columns) instead of the raw all-String
   *      df.schema. A schema that does not match the fitted pipeline's input makes the
   *      PMML conversion fail.
   *   2. VectorIndexer is used as an estimator stage (fitted by the pipeline itself)
   *      rather than pre-fitted outside it.
   *   3. The FileOutputStream is closed in a finally block (was leaked).
   *
   * NOTE(review): the reported exception
   *   "Expected org.apache.spark.ml.Transformer subclass, got ...OneHotEncoder"
   * is thrown while jpmml-sparkml registers its converters and indicates that the
   * jpmml-sparkml artifact does not match the Spark runtime version (OneHotEncoder
   * changed from Transformer to Estimator in Spark 3.0). The dependency version must
   * also be aligned with Spark — that cannot be fixed from application code alone.
   */
  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      System.err.println("需要输入三个参数: <murl> <inputfile> <modelpath>")
      System.exit(1)
    }
    val murl = args(0)      // Spark master URL, e.g. local[2]
    val inputfile = args(1) // training-data CSV path
    val modelpath = args(2) // output PMML file path

    val spark = SparkSession.builder().appName(s"${this.getClass.getName}").master(murl).getOrCreate()

    // header=true: the first CSV line supplies the column names (it is not read as data).
    val df = spark.read.option("header", true).csv(inputfile)
    df.show(3, false)
    df.printSchema()

    // CSV columns arrive as strings; cast the four feature columns to Double and
    // rename the target column to "label".
    val dfDouble = df.select(
      col("sepal_length").cast(DoubleType),
      col("sepal_width").cast(DoubleType),
      col("petal_length").cast(DoubleType),
      col("petal_width").cast(DoubleType),
      col("species").alias("label")
    )
    dfDouble.printSchema()

    // Pipeline stage: combine the scalar feature columns into one vector column.
    // Keeping this INSIDE the pipeline lets PMMLBuilder map the original scalar
    // input fields end-to-end.
    val assembler = new VectorAssembler()
      .setInputCols(Array("sepal_length", "sepal_width", "petal_length", "petal_width"))
      .setOutputCol("features")

    // Index the string label so the classifier can consume it. Fitted on the full
    // data set so every class is known to the indexer.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(dfDouble)

    // Estimator stage: auto-detect categorical features inside the assembled vector;
    // the pipeline fits it on the training data.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")

    // Logistic regression with elastic-net regularization to limit overfitting.
    val logisticRegression = new LogisticRegression()
      .setLabelCol("labelIndex")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(100)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    println("logisticRegression parameters:\n" + logisticRegression.explainParams() + "\n")

    // Map predicted label indices back to the original label strings.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Full pipeline: label indexing -> vector assembly -> feature indexing ->
    // classifier -> index-to-string conversion.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, assembler, featureIndexer, logisticRegression, labelConverter))

    // 70/30 train/test split with a fixed seed for reproducibility.
    // Split the RAW scalar DataFrame: the pipeline now owns all feature handling.
    val Array(trainingData, testData) = dfDouble.randomSplit(Array(0.7, 0.3), 1234L)
    trainingData.show(5, 0)
    testData.show(5, 0)

    val model = pipeline.fit(trainingData)
    println("success fit......")

    val predictions = model.transform(testData)
    predictions.select("predictedLabel", "label", "features", "probability").show(5, 0)

    // Multiclass accuracy on the held-out split; report 1 - accuracy as the test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("labelIndex")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // The LogisticRegressionModel is stage 3 of the fitted pipeline
    // (labelIndexer=0, assembler=1, featureIndexer=2, lr=3, converter=4).
    val lrModel = model.stages(3).asInstanceOf[LogisticRegressionModel]
    println("Learned classification logistic regression model:\n" + lrModel.summary.totalIterations)
    println("Coefficients: \n" + lrModel.coefficientMatrix)
    println("Intercepts: \n" + lrModel.interceptVector)
    println("logistic regression model num of Classes" + lrModel.numClasses)
    println("logistic regression model num of features" + lrModel.numFeatures)

    // Spark-native save, if ever needed:
    // model.write.overwrite().save(modelpath)

    // Export as PMML. The schema passed to PMMLBuilder MUST be the schema of the
    // data the pipeline was fit on (all-Double scalar columns), not the raw
    // all-String df.schema — that mismatch is what broke the original export.
    val pmml = new PMMLBuilder(dfDouble.schema, model).build()
    val fos = new FileOutputStream(modelpath)
    try {
      JAXBUtil.marshalPMML(pmml, new StreamResult(fos))
      println("pmml success......")
    } finally {
      fos.close() // always release the file handle, even if marshalling throws
    }

    spark.stop()
  }
}