
Class DecisionTree

object --+
         |
        DecisionTree


Learning algorithm for a decision tree model
for classification or regression.

EXPERIMENTAL: This is an experimental API.
              It will probably be modified for Spark v1.2.

Example usage:
>>> from numpy import array
>>> import sys
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
...     LabeledPoint(0.0, [0.0]),
...     LabeledPoint(1.0, [1.0]),
...     LabeledPoint(1.0, [2.0]),
...     LabeledPoint(1.0, [3.0])
... ]
>>> categoricalFeaturesInfo = {} # no categorical features
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
...                                      categoricalFeaturesInfo=categoricalFeaturesInfo)
>>> sys.stdout.write(str(model))  # write() expects a string, so convert the model first
DecisionTreeModel classifier
  If (feature 0 <= 0.5)
   Predict: 0.0
  Else (feature 0 > 0.5)
   Predict: 1.0
>>> model.predict(array([1.0])) > 0
True
>>> model.predict(array([0.0])) == 0
True
>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
...                                     categoricalFeaturesInfo=categoricalFeaturesInfo)
>>> model.predict(array([0.0, 1.0])) == 1
True
>>> model.predict(array([0.0, 0.0])) == 0
True
>>> model.predict(SparseVector(2, {1: 1.0})) == 1
True
>>> model.predict(SparseVector(2, {1: 0.0})) == 0
True
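
The example above passes an empty categoricalFeaturesInfo, so every feature is treated as continuous. As a rough sketch of what a non-empty map could look like, the plain-Python helper below derives one from raw feature rows. The helper name build_categorical_features_info is hypothetical and not part of pyspark, and the sketch assumes each categorical value is already encoded as 0, 1, ..., k-1.

```python
def build_categorical_features_info(rows, categorical_indices):
    """Hypothetical helper: map each categorical feature index to its category count.

    Assumes every categorical value is already encoded as 0, 1, ..., k-1,
    so the number of categories is the largest observed code plus one.
    """
    info = {}
    for index in categorical_indices:
        codes = [int(row[index]) for row in rows]
        info[index] = max(codes) + 1
    return info


# Feature 0 is continuous; feature 1 is categorical with codes 0, 1, 2.
rows = [[0.5, 0.0], [1.7, 2.0], [3.2, 1.0]]
categorical_features_info = build_categorical_features_info(rows, [1])
print(categorical_features_info)  # {1: 3}
```

A map built this way would be passed as the categoricalFeaturesInfo argument. Feature 0 is absent from the map, so it stays continuous.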

Instance Methods

Inherited from object: __delattr__, __format__, __getattribute__, __hash__, __init__, __new__, __reduce__, __reduce_ex__, __repr__, __setattr__, __sizeof__, __str__, __subclasshook__

Static Methods
 
trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100)
Train a DecisionTreeModel for classification.
 
trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100)
Train a DecisionTreeModel for regression.

Properties

Inherited from object: __class__

Method Details

trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100)
Static Method

Train a DecisionTreeModel for classification.

:param data: Training data: RDD of LabeledPoint.
             Labels are integers {0, 1, ..., numClasses-1}.
:param numClasses: Number of classes for classification.
:param categoricalFeaturesInfo: Map from categorical feature index
                                to number of categories.
                                Any feature not in this map
                                is treated as continuous.
:param impurity: Supported values: "entropy" or "gini"
:param maxDepth: Max depth of tree.
                 E.g., depth 0 means 1 leaf node.
                 Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
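
To make the two impurity options concrete, here is a minimal plain-Python sketch of the textbook Gini and entropy measures over a list of class labels. It illustrates the standard definitions only and is not Spark's implementation. The function names and the base-2 logarithm are assumptions of this sketch.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class frequencies."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def entropy(labels):
    """Entropy in bits: minus the sum of p * log2(p) over the class frequencies."""
    n = len(labels)
    if n == 0:
        return 0.0
    return -sum((count / n) * math.log2(count / n)
                for count in Counter(labels).values())


# Labels from the classification example: one 0 and three 1s.
print(gini([0, 1, 1, 1]))     # 0.375
# An evenly split node carries one full bit of entropy.
print(entropy([0, 0, 1, 1]))  # 1.0
```

Both measures are zero for a node whose labels all agree and grow as the classes become more mixed.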

trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100)
Static Method

Train a DecisionTreeModel for regression.

:param data: Training data: RDD of LabeledPoint.
             Labels are real numbers.
:param categoricalFeaturesInfo: Map from categorical feature index
                                to number of categories.
                                Any feature not in this map
                                is treated as continuous.
:param impurity: Supported values: "variance"
:param maxDepth: Max depth of tree.
                 E.g., depth 0 means 1 leaf node.
                 Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
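
As an illustration of the "variance" impurity, the sketch below computes the population variance of a set of labels and the reduction in variance achieved by one candidate split. It is plain Python rather than Spark's implementation. The function names and the (label, features) tuple layout are assumptions made for the example.

```python
def variance(labels):
    """Population variance of the labels (0.0 for an empty node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    mean = sum(labels) / n
    return sum((y - mean) ** 2 for y in labels) / n


def variance_reduction(points, feature, threshold):
    """Parent variance minus the size-weighted variance of the two children."""
    n = len(points)
    if n == 0:
        return 0.0
    left = [label for label, features in points if features[feature] <= threshold]
    right = [label for label, features in points if features[feature] > threshold]
    parent = variance([label for label, _ in points])
    weighted = (len(left) / n) * variance(left) + (len(right) / n) * variance(right)
    return parent - weighted


# Dense form of the sparse regression data from the usage example.
points = [
    (0.0, [0.0, 0.0]),
    (1.0, [0.0, 1.0]),
    (0.0, [0.0, 0.0]),
    (1.0, [0.0, 2.0]),
]
gain = variance_reduction(points, feature=1, threshold=0.5)
print(gain)  # 0.25
```

Splitting feature 1 at 0.5 separates the labels perfectly, so the whole parent variance of 0.25 is removed. Splitting on feature 0 leaves every point on one side and yields no reduction.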