在機(jī)器學(xué)習(xí)中,一般都會(huì)按照下面幾個(gè)步驟:特征提取、數(shù)據(jù)預(yù)處理、特征選擇、模型訓(xùn)練、檢驗(yàn)優(yōu)化。那么特征的選擇就很關(guān)鍵了,一般模型最后效果的好壞往往都是跟特征的選擇有關(guān)系的,因?yàn)槟P捅旧淼膮?shù)并沒(méi)有太多優(yōu)化的點(diǎn),反而特征這邊有時(shí)候多加一個(gè)或者少加一個(gè),最終的結(jié)果都會(huì)差別很大。

在SparkMLlib中為我們提供了幾種特征選擇的方法,分別是VectorSlicer、RFormulaChiSqSelector。

下面就介紹下這三個(gè)方法的使用,強(qiáng)烈推薦有時(shí)間的把參考的文獻(xiàn)都閱讀下,會(huì)有所收獲!

VectorSlicer

這個(gè)轉(zhuǎn)換器可以支持用戶自定義選擇列,可以基于下標(biāo)索引,也可以基于列名。

  • 如果是下標(biāo)都可以使用setIndices方法

  • 如果是列名可以使用setNames方法。使用這個(gè)方法的時(shí)候,vector字段需要通過(guò)AttributeGroup設(shè)置每個(gè)向量元素的列名。

注意1:可以同時(shí)使用setInices和setName

object VectorSlicer {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("VectorSlicer-Test").setMaster("local[2]")    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")    var sqlContext = new SQLContext(sc)    val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0,1.0,2.0)))    val defaultAttr = NumericAttribute.defaultAttr    val attrs = Array("f1", "f2", "f3","f4","f5").map(defaultAttr.withName)    val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])    val dataRDD = sc.parallelize(data)    val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField())))    val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")

    slicer.setIndices(Array(0)).setNames(Array("f2"))    val output = slicer.transform(dataset)
    println(output.select("userFeatures", "features").first())
  }
}

注意2:如果下標(biāo)和索引重復(fù),會(huì)報(bào)重復(fù)的錯(cuò):

比如:

slicer.setIndices(Array(1)).setNames(Array("f2"))

那么會(huì)遇到報(bào)錯(cuò)

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: VectorSlicer requires indices and names to be disjoint sets of features, but they overlap. indices: [1]. names: [1:f2]
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.ml.feature.VectorSlicer.getSelectedFeatureIndices(VectorSlicer.scala:137)
    at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:108)
    at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
    at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意3:如果下標(biāo)不存在

slicer.setIndices(Array(6))

如果數(shù)組越界也會(huì)報(bào)錯(cuò)

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 6
    at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3$$anonfun$apply$2.apply(VectorSlicer.scala:110)
    at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3$$anonfun$apply$2.apply(VectorSlicer.scala:110)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:156)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.mutable.ArrayOps$ofInt.map(ArrayOps.scala:156)
    at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3.apply(VectorSlicer.scala:110)
    at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3.apply(VectorSlicer.scala:109)
    at scala.Option.map(Option.scala:145)
    at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:109)
    at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
    at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意4:如果名稱不存在也會(huì)報(bào)錯(cuò)

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: getFeatureIndicesFromNames found no feature with name f8 in column StructField(userFeatures,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,false).
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.ml.util.MetadataUtils$$anonfun$getFeatureIndicesFromNames$2.apply(MetadataUtils.scala:89)
    at org.apache.spark.ml.util.MetadataUtils$$anonfun$getFeatureIndicesFromNames$2.apply(MetadataUtils.scala:88)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:108)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:108)
    at org.apache.spark.ml.util.MetadataUtils$.getFeatureIndicesFromNames(MetadataUtils.scala:88)
    at org.apache.spark.ml.feature.VectorSlicer.getSelectedFeatureIndices(VectorSlicer.scala:129)
    at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:108)
    at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
    at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意5:經(jīng)過(guò)特征選擇后,特征的順序與索引和名稱的順序相同

RFormula

這個(gè)轉(zhuǎn)換器可以幫助基于R模型,自動(dòng)生成feature和label。比如說(shuō)最常用的線性回歸,在先用回歸中,我們需要把一些離散化的變量變成啞變量,即轉(zhuǎn)變成onehot編碼,使之?dāng)?shù)值化,這個(gè)我之前的文章也介紹過(guò),這里就不多說(shuō)了。

如果不是用這個(gè)RFormula,我們可能需要經(jīng)過(guò)幾個(gè)步驟:

StringIndex...OneHotEncoder...

而且每個(gè)特征都要經(jīng)過(guò)這樣的變換,非常繁瑣。有了RFormula,幾乎可以一鍵把所有的特征問(wèn)題解決。

idcoutryhourclicked
7US181.0
8CA120.0
9NZ150.0

然后我們只要寫一個(gè)類似這樣的公式clicked ~ country + hour + my_test,就代表clickedlabel,coutry、hour、my_test是三個(gè)特征

比如下面的代碼:

object RFormulaTest {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("RFormula-Test").setMaster("local[2]")    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")    var sqlContext = new SQLContext(sc)    val dataset = sqlContext.createDataFrame(Seq(
      (7, "US", 18, 1.0,"a"),
      (8, "CA", 12, 0.0,"b"),
      (9, "NZ", 15, 0.0,"a")
    )).toDF("id", "country", "hour", "clicked","my_test")    val formula = new RFormula()
      .setFormula("clicked ~ country + hour + my_test")
      .setFeaturesCol("features")
      .setLabelCol("label")    val output = formula.fit(dataset).transform(dataset)
    output.show()
    output.select("features", "label").show()
  }
}

得到的結(jié)果

+---+-------+----+-------+-------+------------------+-----+
| id|country|hour|clicked|my_test|          features|label|
+---+-------+----+-------+-------+------------------+-----+
|  7|     US|  18|    1.0|      a|[0.0,0.0,18.0,1.0]|  1.0|
|  8|     CA|  12|    0.0|      b|[1.0,0.0,12.0,0.0]|  0.0|
|  9|     NZ|  15|    0.0|      a|[0.0,1.0,15.0,1.0]|  0.0|
+---+-------+----+-------+-------+------------------+-----+

+------------------+-----+
|          features|label|
+------------------+-----+
|[0.0,0.0,18.0,1.0]|  1.0|
|[1.0,0.0,12.0,0.0]|  0.0|
|[0.0,1.0,15.0,1.0]|  0.0|
+------------------+-----+

ChiSqSelector

這個(gè)選擇器支持基于卡方檢驗(yàn)的特征選擇,卡方檢驗(yàn)是一種計(jì)算變量獨(dú)立性的檢驗(yàn)手段。具體的可以參考維基百科,最終的結(jié)論就是卡方的值越大,就是我們?cè)较胍奶卣?。因此這個(gè)選擇器就可以理解為,再計(jì)算卡方的值,最后按照這個(gè)值排序,選擇我們想要的個(gè)數(shù)的特征。

代碼也很簡(jiǎn)單

object ChiSqSelectorTest {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("ChiSqSelector-Test").setMaster("local[2]")    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")    var sqlContext = new SQLContext(sc)    val data = Seq(
      (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
      (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
      (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
    )    val beanRDD = sc.parallelize(data).map(t3 => Bean(t3._1,t3._2,t3._3))    val df = sqlContext.createDataFrame(beanRDD)    val selector = new ChiSqSelector()
      .setNumTopFeatures(2)
      .setFeaturesCol("features")
      .setLabelCol("clicked")
      .setOutputCol("selectedFeatures")    val result = selector.fit(df).transform(df)
    result.show()
  }  case class Bean(id:Double,features:org.apache.spark.mllib.linalg.Vector,clicked:Double){}
}

這樣得到的結(jié)果:

+---+------------------+-------+----------------+
| id|          features|clicked|selectedFeatures|
+---+------------------+-------+----------------+
|7.0|[0.0,0.0,18.0,1.0]|    1.0|      [18.0,1.0]|
|8.0|[0.0,1.0,12.0,0.0]|    0.0|      [12.0,0.0]|
|9.0|[1.0,0.0,15.0,0.1]|    0.0|      [15.0,0.1]|
+---+------------------+-------+----------------+

總結(jié)

下面總結(jié)一下三種特征選擇的使用場(chǎng)景:

  • VectorSilcer,這個(gè)選擇器適合那種有很多特征,并且明確知道自己想要哪個(gè)特征的情況。比如你有一個(gè)很全的用戶畫像系統(tǒng),每個(gè)人有成百上千個(gè)特征,但是你指向抽取用戶對(duì)電影感興趣相關(guān)的特征,因此只要手動(dòng)選擇一下就可以了。

  • RFormula,這個(gè)選擇器適合在需要做OneHotEncoder的時(shí)候,可以一個(gè)簡(jiǎn)單的代碼把所有的離散特征轉(zhuǎn)化成數(shù)值化表示。

  • ChiSqSelector,卡方檢驗(yàn)選擇器適合在你有比較多的特征,但是不知道這些特征哪個(gè)有用,哪個(gè)沒(méi)用,想要通過(guò)某種方式幫助你快速篩選特征,那么這個(gè)方法很適合。

以上的總結(jié)純屬個(gè)人看法,不代表官方做法,如果有其他的見(jiàn)解可以留言~ 多交流!

http://www.cnblogs.com/xing901022/p/7152922.html