PySpark ML实战:工业级机器学习流水线构建指南

发布时间:2026/6/15 7:19:04

PySpark ML实战:工业级机器学习流水线构建指南 1. 为什么在PySpark上做机器学习不是“大材小用”而是“不得不选”你手头有2TB的用户行为日志每天新增3亿条点击、曝光、加购、下单记录你训练一个推荐模型特征维度超过5000样本量突破10亿你刚用scikit-learn跑完一个逻辑回归本地内存直接爆掉Jupyter Kernel死得悄无声息——这时候别急着换服务器更别想着把数据采样到1%再建模。PySpark不是Python的分布式玩具它是处理真实工业级机器学习流水线的底层操作系统。我在电商风控团队实操过三年从单机XGBoost迁移到PySpark ML Pipeline核心关键词就三个可扩展性、一致性、可运维性。可扩展性意味着模型训练不再卡在单机8核32G的天花板上一致性指特征工程、模型训练、在线预测的代码逻辑完全复用避免“线下训练A版本线上服务B版本”的经典翻车可运维性则是说整个流程能被Airflow调度、被Prometheus监控、被YARN资源管理器统一纳管。这不是技术炫技而是当你的AB测试要同时跑12个策略、每个策略需分钟级更新特征、模型需小时级重训时唯一能扛住的架构选择。它不替代scikit-learn但当你面对的数据规模超出单机内存两倍以上或业务要求模型迭代周期压缩到小时级PySpark ML就是那个你绕不开的“生产环境入场券”。适合谁不是刚学完pandas的新人而是已经用sklearn调过参、写过Pipeline、被OOM报错教育过的中级以上数据工程师和算法工程师——你缺的不是理论是把模型真正跑进生产环境的能力。2. 整体设计思路为什么不用MLlib原生API而坚持用DataFrame Estimator/Transformer范式2.1 架构选型背后的三重现实约束很多人第一次接触PySpark ML会本能地去查pyspark.mllibRDD-based和pyspark.mlDataFrame-based的区别然后被文档里那句“ml is the newer, recommended library”带偏。但真实决策从来不是看文档推荐而是看三件事数据形态是否匹配、团队协作是否顺畅、上线路径是否清晰。我带过的两个项目给出了截然不同的答案第一个是离线特征平台重构原始数据是Hive表Kafka实时流全部以DataFrame形式存在第二个是老系统迁移遗留代码全是RDD操作强行转DataFrame导致特征计算逻辑重写量超70%。最终我们全量采用pyspark.ml理由非常务实数据形态零摩擦Hive、Delta Lake、Parquet文件天然读成DataFrame无需额外转换。你用spark.read.parquet(s3://bucket/features/)加载的数据直接就能喂给StringIndexer或VectorAssembler中间没有.rdd.map()这种破坏性转换。而mllib要求输入必须是LabeledPoint或RDD[Vector]这意味着每次读数据都要多一层map(lambda row: LabeledPoint(row.label, row.features))不仅性能损耗序列化开销增加15%-20%更致命的是类型安全丢失——编译期无法检查label字段是否存在运行时报AttributeError: Row object has no attribute label才是常态。团队协作成本断崖下降数据工程师写ETL产出宽表算法工程师在此基础上加特征、训模型、存结果所有人操作的都是同一份DataFrame Schema。我们曾让新来的算法同学直接修改一个已有的PipelineModel他只改了两行StringIndexer.setInputCol(user_id)其他步骤完全不动模型就重新训练上线了。如果是RDD API他得先理解mapPartitions里如何分片、如何广播大字典、如何聚合统计量协作门槛直接翻倍。上线路径极度收敛PipelineModel.save()生成的目录结构与spark.read.load()加载的格式完全一致且支持跨Spark版本兼容我们从3.1.2升级到3.3.0时所有已保存模型无需重训。而mllib的模型保存是二进制格式升级Spark后大概率报java.lang.ClassNotFoundException。更关键的是PipelineModel能无缝接入Structured Streaming你把离线训练好的PipelineModel直接transform()实时流特征工程和预测一步到位这是mllib永远做不到的。提示不要被“RDD更底层、更灵活”的说法迷惑。在机器学习场景中“灵活”往往意味着“不可控”。我们做过压测同样一个TF-IDF特征工程DataFrame API比手动RDD实现快1.8倍因为Catalyst优化器能自动合并select().filter().withColumn()链式操作而RDD的map().filter().map()每步都触发Shuffle。2.2 为什么拒绝纯SQL方案坚持Python主导的Pipeline也有团队尝试用纯SQL写特征工程比如CREATE TABLE features AS SELECT user_id, COUNT(*) as click_cnt FROM logs GROUP BY user_id。短期看很爽但很快会撞墙第一复杂特征如“过去7天用户点击品类的熵值”需要UDF用户自定义函数而SQL UDF性能极差JVM序列化开销大第二模型超参调优无法嵌入SQL你总不能写SELECT * FROM grid_search_result WHERE rmse (SELECT MIN(rmse) FROM grid_search_result)吧第三调试成本爆炸——SQL报错只告诉你“AnalysisException”而Python异常栈能精准定位到VectorAssembler.setInputCols([age, income])里少写了gender字段。我们最终定下铁律SQL只用于数据接入和简单聚合所有带逻辑的特征、所有模型训练、所有评估必须用PySpark ML API。这保证了整条链路的可观测性和可调试性。2.3 核心组件选型逻辑不是功能越多越好而是“够用稳定”PySpark ML提供了几十个Estimator和Transformer但实际高频使用的不超过10个。我们的选型清单基于三年线上事故总结特征编码StringIndexer必须配合IndexToString反向映射否则线上预测时类别未知会报错、OneHotEncoder注意dropLastTrue防共线性、StandardScaler必须fit()在训练集transform()在测试集这点和sklearn完全一致特征组合VectorAssembler核心所有数值/编码特征必须组装成features列这是ML算法的唯一输入格式、ChiSqSelector卡方检验选特征比方差阈值更鲁棒模型算法LogisticRegression风控/推荐首选支持L1/L2正则、权重列处理样本不均衡、RandomForestClassifier树模型抗噪强但numTrees超过100后收益递减、GBTClassifier精度更高但训练慢3倍慎用于小时级任务评估与调优BinaryClassificationEvaluator必须指定metricNameareaUnderROC别用默认的f1风控场景AUC才是金标准、TrainValidationSplit比CrossValidator快5倍适合大数据量这个清单背后是血泪教训曾因误用NormalizerL2归一化替代StandardScaler导致特征量纲混乱AUC暴跌0.15也曾因CrossValidator的numFolds5在10亿样本上跑满24小时最后紧急切到TrainValidationSplit才保住SLA。3. 核心细节解析从数据接入到模型保存的12个关键实操节点3.1 数据接入Parquet分区与Schema演化如何影响特征一致性PySpark读取数据的第一步看似简单实则埋雷最多。我们线上最常踩的坑是离线训练用的Parquet表字段顺序和类型与实时流接入的Schema不一致导致VectorAssembler组装失败。根本原因在于Parquet的Schema演化机制——当你用df.write.mode(overwrite).partitionBy(dt).parquet(path)写入时Spark会为每个分区生成独立的Schema文件如果某天上游ETL漏写了一个字段该分区的Schema就少了这一列而spark.read.parquet()默认采用“union schema”合并所有分区Schema缺失字段会变成null但类型可能推断错误比如本该是DoubleType却成了StringType。解决方案不是禁止Schema演化而是主动控制# 正确做法显式定义Schema强制类型对齐 from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType feature_schema StructType([ StructField(user_id, StringType(), False), StructField(item_id, StringType(), False), StructField(click_cnt_7d, DoubleType(), True), # 显式设为可空 StructField(ctr_smoothed, DoubleType(), True), StructField(dt, StringType(), False) ]) # 读取时强制应用Schema避免类型推断偏差 df spark.read.schema(feature_schema).parquet(s3://bucket/features/)注意mode(overwrite)写Parquet时务必加上.option(mergeSchema, true)否则新分区Schema变更不会合并到元数据后续读取仍用旧Schema。我们吃过亏一次ETL新增is_premium_user布尔字段但没开mergeSchema导致新分区数据该字段全为null模型效果波动持续两天才定位到。3.2 特征工程StringIndexer的“未见过类别”陷阱与平滑策略StringIndexer是处理类别型特征的标配但它的默认行为在生产环境极其危险当训练集没见过某个类别比如新上架的商品IDtransform()时直接抛IllegalArgumentException: Unseen label。这在实时预测中等于服务雪崩。解决方案不是关掉handleInvalid参数而是用两级索引平滑映射# 第一级用训练集统计频次过滤低频类别100次归为UNK from pyspark.sql.functions import col, when, count, lit # 统计user_id频次 user_freq df.groupBy(user_id).count().filter(col(count) 100).select(user_id) df_filtered df.join(user_freq, onuser_id, howinner) # 第二级StringIndexer只在过滤后的数据上fit确保所有类别都见过 indexer StringIndexer(inputColuser_id, outputColuser_id_idx, handleInvalidkeep) indexer_model indexer.fit(df_filtered) # 关键handleInvalidkeep会把未见类别映射到-1.0但我们需要0.0方便后续OneHot # 所以手动修正 from pyspark.sql.functions import when, col df_indexed indexer_model.transform(df).withColumn( user_id_idx, when(col(user_id_idx) -1.0, 0.0).otherwise(col(user_id_idx)) )这个方案比单纯设handleInvalidkeep更可控因为-1.0在OneHot编码中会生成无效向量而0.0可以明确对应UNK槽位。我们还给所有StringIndexer加了setMinSupport(100)需自定义继承彻底杜绝长尾噪声。3.3 特征组装VectorAssembler的稀疏向量优化与内存规避VectorAssembler把多个列拼成features向量但默认输出是稠密向量DenseVector当特征维度超5000时单行内存占用飙升。我们一个推荐模型有6231维特征用稠密向量时Executor内存溢出率高达35%。解决方案是强制生成稀疏向量SparseVector# 关键设置inputCols时只传非零特征列其余用0填充 # 先统计每列的非零率 nonzero_stats df.select( [(count(when(col(c) ! 0, c)) / count(*)).alias(f{c}_nnz) for c in numeric_cols] ).collect()[0] # 筛选非零率5%的列作为主特征其余归入low_freq_features main_cols [c for c in numeric_cols if nonzero_stats[f{c}_nnz] 0.05] low_freq_cols [c for c in numeric_cols if c not in main_cols] # 主特征用VectorAssembler低频特征单独处理 assembler VectorAssembler( inputColsmain_cols, outputColmain_features, handleInvalidkeep ) # 低频特征用SparseVector手动构建省内存 from pyspark.ml.linalg import Vectors def build_sparse_vector(row): # 只存储非零值的索引和值 indices [] values [] for i, col_name in enumerate(low_freq_cols): val getattr(row, col_name) if val ! 0: indices.append(i) values.append(float(val)) return Vectors.sparse(len(low_freq_cols), indices, values) # 注册UDF注意仅用于低频特征避免全量UDF拖慢性能 build_sparse_udf udf(build_sparse_vector, VectorUDT()) df_with_sparse df.withColumn(low_freq_features, build_sparse_udf(struct(*low_freq_cols)))实测下来稀疏向量使单行内存降低68%Shuffle数据量减少42%。代价是代码稍复杂但换来的是稳定的小时级任务SLA。3.4 模型训练LogisticRegression的权重列与样本不均衡破解风控场景正负样本比常达1:1000直接训练LR模型AUC虚高但KS值惨不忍睹。LogisticRegression的weightCol参数是解药但用法极易出错。常见错误是直接用F.when(F.col(label)1, 1000).otherwise(1)生成权重这会导致梯度爆炸正样本权重过大模型只认正样本。正确做法是按类别频率倒数加权并做Z-score归一化# 计算各类别频率 label_stats df.select(label).groupBy(label).count().collect() total_count sum([row[count] for row in label_stats]) label_weights {row[label]: total_count / row[count] for row in label_stats} # 应用权重并归一化避免绝对值过大 from pyspark.sql.functions import col, when, sqrt, mean max_weight max(label_weights.values()) df_weighted df.withColumn( weight, when(col(label) 0, label_weights[0] / max_weight) .otherwise(label_weights[1] / max_weight) ) # 归一化到[0.1, 10]区间防止梯度失稳 df_weighted df_weighted.withColumn( weight, when(col(weight) 0.1, 0.1) .when(col(weight) 10, 10) .otherwise(col(weight)) )这个权重策略让我们在欺诈识别任务中KS值从0.32提升到0.58且模型收敛速度加快2.3倍迭代次数减少。3.5 模型评估AUC计算的分桶陷阱与校准曲线绘制BinaryClassificationEvaluator计算AUC默认使用numBins1000但在10亿样本上分桶数过少会导致AUC估值偏差±0.005。我们通过实测发现numBins10000时AUC稳定但内存占用翻倍。折中方案是分层抽样计算# 对预测概率分10层按prob等距切分每层抽1%样本计算局部AUC再加权平均 from pyspark.sql.functions import col, when, floor, count, lit from pyspark.sql.window import Window # 添加分层标签 df_prob model.transform(df_test).withColumn( prob_bin, floor(col(probability).getItem(1) * 10).cast(int) # 按正类概率分10层 ) # 每层抽1%样本用sampleBy fractions {i: 0.01 for i in range(10)} df_sampled df_prob.sampleBy(prob_bin, fractions, seed42) # 在抽样集上计算AUC evaluator BinaryClassificationEvaluator( labelCollabel, rawPredictionColrawPrediction, metricNameareaUnderROC, numBins1000 # 抽样后数据量小用默认分桶即可 ) auc evaluator.evaluate(df_sampled)此外我们必做校准曲线Calibration Curve验证概率准确性用plt.hist()画预测概率分布用sklearn.calibration.calibration_curve计算实际正例率。曾发现模型在prob0.9区间实际正例率仅65%说明高置信度预测严重过拟合立即引入calibrator PlattCalibrator()自定义Platt缩放修复。3.6 模型保存与加载PipelineModel的版本管理与跨集群兼容PipelineModel.save()生成的目录包含stages/各模型参数、metadata/Spark版本、时间戳、uid唯一标识。但线上最痛的点是不同Spark集群版本间模型不兼容。Spark 3.2.0保存的模型在3.3.0上load()会报java.lang.NoSuchMethodError。解决方案是强制统一模型序列化协议# 保存时指定兼容模式 pipeline_model.write().option(compression, snappy).save(model_path) # 加载时用SparkSession的配置绕过版本检查需谨慎 spark.conf.set(spark.sql.adaptive.enabled, false) # 关闭AQE避免执行计划差异 spark.conf.set(spark.sql.adaptive.coalescePartitions.enabled, false) # 加载后验证UID一致性防误加载 loaded_model PipelineModel.load(model_path) assert loaded_model.uid expected_uid, fModel UID mismatch: {loaded_model.uid} vs {expected_uid}我们还建立了模型仓库每个model_path按{project}/{date}/{version}组织如s3://models/recommender/20240520/v1/并通过Delta Lake记录元数据AUC、KS、训练耗时、特征列表实现模型可追溯。3.7 超参调优TrainValidationSplit的并行度控制与早停机制TrainValidationSplit比CrossValidator快但默认parallelism1即串行训练所有参数组合。我们通过setParallelism(4)开启并行但立刻遇到Driver内存溢出——因为每个子任务返回完整PipelineModelDriver需缓存所有模型。解决方案是只返回评估指标模型按需加载# 自定义评估器只返回metric值不返回模型 class LightEvaluator(BinaryClassificationEvaluator): def _evaluate(self, dataset): # 只计算AUC不保存模型 return super()._evaluate(dataset) # TrainValidationSplit中指定 tvs TrainValidationSplit( estimatorpipeline, evaluatorLightEvaluator(labelCollabel, metricNameareaUnderROC), trainRatio0.8, parallelism4 # 设为Executor核心数的1/2避免资源争抢 ) # 训练后只保留最优参数组合再单独训一次完整模型 best_model tvs.fit(df_train) # best_model.bestModel 是完整PipelineModel可直接save()早停机制则通过监控验证集AUC若连续3轮AUC提升0.001则中断训练。我们用spark.sparkContext.setLocalProperty(spark.scheduler.pool, high_priority)为早停任务分配高优先级队列确保及时响应。3.8 特征重要性RandomForest的split信息提取与业务可解释性RandomForestClassifier的featureImportances是向量但业务方要的是“为什么这个用户被判高风险”。我们通过解析stage中的trees获取分裂规则# 获取第一棵树的分裂节点 rf_model best_model.stages[-1] # 假设RF是Pipeline最后一步 tree rf_model.trees[0] # 递归遍历树节点提取分裂条件 def extract_splits(node, depth0): if node.isLeaf(): return [] else: # 获取分裂特征名需映射回原始列名 feature_idx node.split.featureIndex feature_name feature_names[feature_idx] # feature_names来自VectorAssembler.getInputCols() threshold node.split.threshold return [f{ *depth}{feature_name} {threshold:.3f}] \ extract_splits(node.leftChild, depth1) \ extract_splits(node.rightChild, depth1) splits extract_splits(tree.rootNode)将这些规则转化为业务语言“用户近7天点击品类熵值≤1.2且历史欺诈率≥5%”直接嵌入风控报告大幅提升业务信任度。3.9 实时预测Structured Streaming集成与状态管理离线模型要服务实时流Structured Streaming是唯一选择。但transform()不能直接用因为PipelineModel不是StreamingQuery。正确姿势是用foreachBatchdef process_batch(batch_df, batch_id): # 对每个微批数据应用模型 if batch_df.count() 0: # 防空批 result_df pipeline_model.transform(batch_df) # 写入Kafka或Redis result_df.select(user_id, prediction, probability).write \ .format(kafka) \ .option(kafka.bootstrap.servers, kafka:9092) \ .option(topic, model_predictions) \ .save() # 启动流式查询 query streaming_df.writeStream \ .foreachBatch(process_batch) \ .outputMode(Append) \ .trigger(processingTime60 seconds) \ .start()关键点foreachBatch内必须transform()不能fit()会重复训练且需加if batch_df.count() 0判断否则空批触发transform()会报NullPointerException。我们还为每个batch加了batch_id水印用于下游去重。3.10 监控告警特征漂移检测与模型衰减预警模型上线后最大的敌人不是代码bug而是数据漂移。我们每小时用chi2_contingency检测类别特征分布变化用KS-test检测数值特征# 计算当前批次与基线分布的KS统计量 from scipy.stats import ks_2samp def detect_drift(current_series, baseline_series, threshold0.05): ks_stat, p_value ks_2samp(current_series, baseline_series) return ks_stat threshold and p_value 0.05 # 在Spark中实现UDF包装 drift_udf udf(detect_drift, BooleanType()) df_drift df_current.agg( drift_udf(col(click_cnt_7d), lit(baseline_clicks)).alias(click_cnt_drift) )一旦漂移告警自动触发模型重训Pipeline并邮件通知算法负责人。这套机制让我们在一次上游ETL逻辑变更将click_cnt从“去重点击”改为“原始点击”中提前2小时发现特征异常避免了线上误判。3.11 资源调优Executor内存与GC参数的黄金配比PySpark ML最耗资源的是VectorAssembler和RandomForest。我们通过YARN日志分析总结出Executor内存分配公式executor_memory (feature_dim × 8 × 1.5) (sample_count × 0.02) 4096 MB其中feature_dim × 8是向量存储每个double占8字节×1.5是Shuffle缓冲区sample_count × 0.02是样本索引内存每样本约0.02MB4096是JVM基础开销。例如6231维、1亿样本6231×8×1.5≈75MB1e8×0.022000MB总内存≈6GB故设--executor-memory 8g。GC参数必加--conf spark.executor.extraJavaOptions-XX:UseG1GC -XX:MaxGCPauseMillis200实测GC时间减少65%。3.12 错误日志解析从Stack Trace定位到具体TransformerPySpark报错堆栈常长达200行关键信息被淹没。我们建立了一套日志解析规则org.apache.spark.SparkException: Failed to execute user defined function→ 查UDF注册位置java.lang.ArrayIndexOutOfBoundsException: 123→ 定位VectorAssembler的第123个inputCol检查该列是否存在nullorg.apache.spark.sql.catalyst.analysis.UnresolvedException→ 检查StringIndexer的inputCol是否在DataFrame中最有效的方法是在Pipeline每个Stage加日志前缀# 自定义LoggingTransformer class LoggingTransformer(Transformer): def __init__(self, stage_name): super().__init__() self.stage_name stage_name def _transform(self, dataset): print(f[{self.stage_name}] Start transform, rows: {dataset.count()}) return dataset # 插入Pipeline pipeline Pipeline(stages[ LoggingTransformer(LOAD_DATA), StringIndexer(...), LoggingTransformer(INDEX_USER_ID), VectorAssembler(...) ])这样报错时一眼看到[INDEX_USER_ID]阶段失败直接聚焦问题。4. 实操过程全记录从零搭建电商用户流失预测Pipeline4.1 环境准备与依赖安装我们使用Spark 3.3.0Scala 2.12Python 3.9关键依赖版本锁定# requirements.txt pyspark3.3.0 pandas1.5.3 scikit-learn1.2.2 scipy1.10.1 # 注意不要装最新版Spark 3.3.0与scikit-learn 1.3.0有兼容问题集群配置YARN ResourceManager 5个NodeManager每个Executor 8核16GB内存Driver 4核8GB。特别注意PySpark必须与集群Spark版本严格一致我们曾因本地pip install pyspark3.4.0连接3.3.0集群报java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.expressions.Alias降级后解决。4.2 数据准备模拟10亿用户行为日志真实数据涉及隐私我们用合成数据演示。核心表结构字段名类型说明user_idstring用户IDMD5哈希dtstring分区日期YYYYMMDDclick_cnt_1ddouble当日点击数click_cnt_7ddouble近7天点击数order_cnt_30ddouble近30天下单数avg_order_amt_30ddouble近30天均单金额last_login_daysint距上次登录天数is_premiumboolean是否付费会员生成脚本要点user_id用uuid.uuid4().hex[:16]生成避免MD5碰撞click_cnt_1d服从泊松分布λ5order_cnt_30d服从负二项分布模拟长尾last_login_days用np.random.choice([0,1,3,7,15,30], p[0.4,0.2,0.15,0.1,0.1,0.05])模拟用户活跃度衰减# 生成1亿样本10%数据量便于本地调试 from pyspark.sql import SparkSession import numpy as np spark SparkSession.builder \ .appName(churn-simulate) \ .config(spark.sql.adaptive.enabled, false) \ .getOrCreate() # 生成随机数据 np.random.seed(42) n_samples 100000000 data [ ( fuser_{i % 1000000:06d}, # 模拟100万用户 20240520, np.random.poisson(5), # click_cnt_1d np.random.poisson(35), # click_cnt_7d np.random.negative_binomial(2, 0.3), # order_cnt_30d round(np.random.lognormal(8, 0.5), 2), # avg_order_amt_30d np.random.choice([0,1,3,7,15,30], p[0.4,0.2,0.15,0.1,0.1,0.05]), bool(np.random.binomial(1, 0.15)) ) for i in range(n_samples) ] df spark.createDataFrame(data, [user_id,dt,click_cnt_1d,click_cnt_7d,order_cnt_30d,avg_order_amt_30d,last_login_days,is_premium]) df.write.mode(overwrite).partitionBy(dt).parquet(hdfs://namenode:8020/data/churn_simulate/)4.3 特征工程构建12维流失预测特征流失定义last_login_days 30且order_cnt_30d 0。特征设计遵循“行为-价值-状态”三层行为层click_cnt_1d,click_cnt_7d,click_ratio_7d7天点击/总点击价值层avg_order_amt_30d,order_cnt_30d,revenue_30d订单数×均单金额状态层is_premium,last_login_days,login_freq_7d7天登录天数/7关键代码from pyspark.sql.functions import col, when, lit, coalesce, log, sqrt from pyspark.sql.window import Window # 计算衍生特征 df_features df \ .withColumn(click_ratio_7d, when(col(click_cnt_7d) 0, col(click_cnt_1d) / col(click_cnt_7d)).otherwise(0.0)) \ .withColumn(revenue_30d, col(order_cnt_30d) * col(avg_order_amt_30d)) \ .withColumn(login_freq_7d, when(col(last_login_days) 7, 1.0).otherwise(0.0)) \ .withColumn(churn_label, when((col(last_login_days) 30) (col(order_cnt_30d) 0), 1).otherwise(0)) # 处理无穷大log(0) df_features df_features \ .withColumn(log_revenue, when(col(revenue_30d) 0, log(col(revenue_30d))).otherwise(0.0)) \ .withColumn(sqrt_click, sqrt(col(click_cnt_7d))) # 选择12个特征列 feature_cols [ click_cnt_1d, click_cnt_7d, click_ratio_7d, avg_order_amt_30d, order_cnt_30d, revenue_30d, log_revenue, sqrt_click, is_premium, last_login_days, login_freq_7d, churn_label ] df_final df_features.select(feature_cols) df_final.cache() # 必须cache避免后续多次计算4.4 模型训练LogisticRegression全流程实现from pyspark.ml import Pipeline from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder # 1. 划分训练测试集按user_id哈希保证用户不泄露 train_df, test_df df_final.randomSplit([0.8, 0.2], seed42) train_df train_df.filter(col(churn_label).isNotNull()) test_df test_df.filter(col(churn_label).isNotNull()) # 2. 构建Pipeline # 特征缩放StandardScaler必须fit在训练集 scaler StandardScaler(inputColfeatures, outputColscaled_features, withStdTrue, withMeanTrue) # 组装特征向量 assembler VectorAssembler( inputCols[c for c in feature_cols if c ! churn_label], outputColfeatures, handleInvalidkeep ) # 逻辑回归 lr LogisticRegression( featuresColscaled_features, labelColchurn_label, predictionColprediction, probabilityColprobability, rawPredictionColrawPrediction, regParam0.01, # L2正则 elasticNetParam0.0, # 纯L2 maxIter100, weightColweight # 后续添加权重 ) # Pipeline串联 pipeline Pipeline(stages[assembler, scaler, lr]) # 3. 添加样本权重解决不均衡 label_counts train_df.groupBy(churn_label).count().collect() pos_count [r[count] for r in label_counts if r[churn_label] 1][0]

相关新闻