Skip to content

Commit

Permalink
Merge pull request sjwhitworth#99 from Sentimentron/params-staging
Browse files Browse the repository at this point in the history
MultiLinearSVC class weights
  • Loading branch information
sjwhitworth committed Nov 21, 2014
2 parents ec846e6 + 8fe06e7 commit aeb12bd
Show file tree
Hide file tree
Showing 7 changed files with 300 additions and 27 deletions.
63 changes: 63 additions & 0 deletions base/util_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,69 @@ func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
return nil
}

// GetClassDistributionByBinaryFloatValue returns a two-element count
// vector for a binary FloatAttribute class: index 1 counts rows whose
// class value exceeds 0.5, index 0 counts the remainder.
//
// Panics unless inst has exactly one class Attribute and that
// Attribute is a *FloatAttribute.
func GetClassDistributionByBinaryFloatValue(inst FixedDataGrid) []int {

	// Validate the class variable
	classAttrs := inst.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(classAttrs)))
	}
	if _, ok := classAttrs[0].(*FloatAttribute); !ok {
		panic(fmt.Errorf("Class Attribute must be FloatAttribute (is %s)", classAttrs[0]))
	}

	// One counter per binary outcome
	counts := make([]int, 2)

	// Tally each row into the appropriate bucket
	classSpecs := ResolveAttributes(inst, classAttrs)
	inst.MapOverRows(classSpecs, func(rowBytes [][]byte, _ int) (bool, error) {
		if UnpackBytesToFloat(rowBytes[0]) > 0.5 {
			counts[1]++
		} else {
			counts[0]++
		}
		return false, nil
	})

	return counts
}

// GetClassDistributionByCategoricalValue returns a vector containing
// the count of each class value (indexed by the class' system
// integer representation).
//
// Panics unless inst has exactly one class Attribute and that
// Attribute is a *CategoricalAttribute.
func GetClassDistributionByCategoricalValue(inst FixedDataGrid) []int {

	var classAttr *CategoricalAttribute
	var ok bool
	// Get the class variable
	attrs := inst.AllClassAttributes()
	if len(attrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
	}
	if classAttr, ok = attrs[0].(*CategoricalAttribute); !ok {
		panic(fmt.Errorf("Class Attribute must be a CategoricalAttribute (is %s)", attrs[0]))
	}

	// Get the number of class values
	classLen := len(classAttr.GetValues())
	ret := make([]int, classLen)

	// Map through everything, counting each row under its class index
	specs := ResolveAttributes(inst, attrs)
	inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
		index := UnpackBytesToU64(vals[0])
		ret[int(index)]++
		return false, nil
	})

	return ret
}

// GetClassDistribution returns a map containing the count of each
// class type (indexed by the class' string representation).
func GetClassDistribution(inst FixedDataGrid) map[string]int {
Expand Down
30 changes: 27 additions & 3 deletions ensemble/multisvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,38 @@ type MultiLinearSVC struct {
// whether the system solves the dual or primal SVM form, true should be used
// in most cases. C is the penalty term, normally 1.0. eps is the convergence
term, typically 1e-4. weights maps each class value to its misclassification
penalty weight; pass nil to weight all classes equally.
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64) *MultiLinearSVC {
classifierFunc := func() base.Classifier {
ret, err := linear_models.NewLinearSVC(loss, penalty, dual, C, eps)
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
// Set up the training parameters
params := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
err := params.SetKindFromStrings(loss, penalty)
if err != nil {
panic(err)
}

// Classifier creation function
classifierFunc := func(cls string) base.Classifier {
var weightVec []float64
newParams := params.Copy()
if weights != nil {
weightVec = make([]float64, 2)
for i := range weights {
if i != cls {
weightVec[0] += weights[i]
} else {
weightVec[1] = weights[i]
}
}
}
newParams.ClassWeights = weightVec

ret, err := linear_models.NewLinearSVCFromParams(newParams)
if err != nil {
panic(err)
}
return ret
}

// Return me...
return &MultiLinearSVC{
meta.NewOneVsAllModel(classifierFunc),
}
Expand Down
30 changes: 26 additions & 4 deletions ensemble/multisvc_test.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
package ensemble

import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
. "github.com/smartystreets/goconvey/convey"
"testing"
)

func TestMultiSVM(t *testing.T) {
func TestMultiSVMUnweighted(t *testing.T) {
Convey("Loading data...", t, func() {
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
So(err, ShouldBeNil)
X, Y := base.InstancesTrainTestSplit(inst, 0.4)

m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4)
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
m.Fit(X)

Convey("Predictions should work...", func() {
predictions, err := m.Predict(Y)
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
So(err, ShouldEqual, nil)
fmt.Println(evaluation.GetSummary(cf))
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
})
})
}

// TestMultiSVMWeighted checks that a class-weighted one-vs-all linear SVC
// ensemble reaches reasonable accuracy on the articles dataset.
func TestMultiSVMWeighted(t *testing.T) {
	Convey("Loading data...", t, func() {
		weights := make(map[string]float64)
		weights["Finance"] = 0.1739
		weights["Tech"] = 0.0750
		weights["Politics"] = 0.4928

		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		X, Y := base.InstancesTrainTestSplit(inst, 0.4)

		m := NewMultiLinearSVC("l1", "l2", true, 0.62, 1e-4, weights)
		m.Fit(X)

		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(Y)
			// Check the Predict error before it is overwritten by the
			// GetConfusionMatrix assignment below.
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
			So(err, ShouldEqual, nil)
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}
164 changes: 149 additions & 15 deletions linear_models/linearsvc.go
Original file line number Diff line number Diff line change
@@ -1,54 +1,186 @@
package linear_models

import "C"

import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"unsafe"
)

type LinearSVC struct {
param *Parameter
model *Model
// LinearSVCParams represents all available LinearSVC options.
//
// SolverType: can be linear_models.L2R_L1LOSS_SVC_DUAL,
// L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC or L1R_L2LOSS_SVC.
// It must be set via SetKindFromStrings.
//
// ClassWeights describes how each class is weighted, and can
// be used in class-imbalanced scenarios. If this is nil, then
// all classes will be weighted the same unless WeightClassesAutomatically
// is True.
//
// C is a float64 representing the misclassification penalty.
//
// Eps is a float64 convergence threshold.
//
// Dual indicates whether the solution is primal or dual.
type LinearSVCParams struct {
	SolverType                 int
	ClassWeights               []float64
	C                          float64
	Eps                        float64
	WeightClassesAutomatically bool
	Dual                       bool
}

func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
solver_type := 0
// Copy returns a copy of these parameters; the ClassWeights slice is
// duplicated so the copy can be mutated without affecting the original.
func (p *LinearSVCParams) Copy() *LinearSVCParams {
	out := &LinearSVCParams{
		SolverType:                 p.SolverType,
		ClassWeights:               nil,
		C:                          p.C,
		Eps:                        p.Eps,
		WeightClassesAutomatically: p.WeightClassesAutomatically,
		Dual:                       p.Dual,
	}
	if p.ClassWeights != nil {
		out.ClassWeights = append([]float64(nil), p.ClassWeights...)
	}
	return out
}

// SetKindFromStrings configures the solver kind from strings.
// Penalty and Loss parameters can either be l1 or l2.
func (p *LinearSVCParams) SetKindFromStrings(loss, penalty string) error {
var ret error
p.SolverType = 0
// Loss validation
if loss == "l1" {
} else if loss == "l2" {
} else {
return fmt.Errorf("loss must be \"l1\" or \"l2\"")
}

// Penalty validation
if penalty == "l2" {
if loss == "l1" {
if dual {
solver_type = L2R_L1LOSS_SVC_DUAL
if !p.Dual {
ret = fmt.Errorf("Important: changed to dual form")
}
p.SolverType = L2R_L1LOSS_SVC_DUAL
p.Dual = true
} else {
if dual {
solver_type = L2R_L2LOSS_SVC_DUAL
if p.Dual {
p.SolverType = L2R_L2LOSS_SVC_DUAL
} else {
solver_type = L2R_L2LOSS_SVC
p.SolverType = L2R_L2LOSS_SVC
}
}
} else if penalty == "l1" {
if loss == "l2" {
if !dual {
solver_type = L1R_L2LOSS_SVC
if p.Dual {
ret = fmt.Errorf("Important: changed to primary form")
}
p.Dual = false
p.SolverType = L1R_L2LOSS_SVC
} else {
return fmt.Errorf("Must have L2 loss with L1 penalty")
}
} else {
return fmt.Errorf("Penalty must be \"l1\" or \"l2\"")
}
if solver_type == 0 {
panic("Parameter combination")

// Final validation
if p.SolverType == 0 {
return fmt.Errorf("Invalid parameter combination")
}
return ret
}

// convertToNativeFormat converts the LinearSVCParams given into a format
// for liblinear. Only SolverType, C and Eps are forwarded here; class
// weights are applied separately during Fit.
func (p *LinearSVCParams) convertToNativeFormat() *Parameter {
	return NewParameter(p.SolverType, p.C, p.Eps)
}

// LinearSVC represents a linear support-vector classifier.
type LinearSVC struct {
	param *Parameter       // native liblinear parameter block, derived from Param
	model *Model           // trained liblinear model; nil until Fit succeeds
	Param *LinearSVCParams // public training configuration
}

// NewLinearSVC creates a new support-vector classifier.
//
// loss and penalty: see LinearSVCParams#SetKindFromString
//
// dual: see LinearSVCParams
//
// eps: see LinearSVCParams
//
// C: see LinearSVCParams
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
	// Build the parameter structure and derive the solver kind from the
	// loss/penalty strings; remaining fields keep their zero values.
	params := &LinearSVCParams{
		C:    C,
		Eps:  eps,
		Dual: dual,
	}
	if err := params.SetKindFromStrings(loss, penalty); err != nil {
		return nil, err
	}
	return NewLinearSVCFromParams(params)
}

// NewLinearSVCFromParams constructs a LinearSVC from the given LinearSVCParams structure.
func NewLinearSVCFromParams(params *LinearSVCParams) (*LinearSVC, error) {
	// The model field stays nil until Fit is called.
	return &LinearSVC{
		param: params.convertToNativeFormat(),
		Param: params,
	}, nil
}

// Fit automatically weights the class vector (if configured to do so),
// converts the FixedDataGrid into the right format and trains the model.
func (lr *LinearSVC) Fit(X base.FixedDataGrid) error {

	var weightVec []float64
	var weightClasses []C.int

	// Creates the class weighting: explicit ClassWeights take priority,
	// otherwise a vector is generated from the data (by distribution or
	// uniformly, depending on WeightClassesAutomatically).
	if lr.Param.ClassWeights == nil {
		if lr.Param.WeightClassesAutomatically {
			weightVec = generateClassWeightVectorFromDist(X)
		} else {
			weightVec = generateClassWeightVectorFromFixed(X)
		}
	} else {
		weightVec = lr.Param.ClassWeights
	}

	// liblinear pairs each weight with a class label; label i is the
	// i-th class in system integer representation.
	weightClasses = make([]C.int, len(weightVec))
	for i := range weightVec {
		weightClasses[i] = C.int(i)
	}

	// Convert the problem
	problemVec := convertInstancesToProblemVec(X)
	labelVec := convertInstancesToLabelVec(X)

	// Train: push the weight arrays into the native parameter struct.
	// NOTE(review): &weightClasses[0] / &weightVec[0] panic on an empty
	// weight vector — confirm the generators always return >= 1 element.
	prob := NewProblem(problemVec, labelVec, 0)
	lr.param.c_param.nr_weight = C.int(len(weightVec))
	lr.param.c_param.weight_label = &(weightClasses[0])
	lr.param.c_param.weight = (*C.double)(unsafe.Pointer(&weightVec[0]))

	// lr.param.weights = (*C.double)unsafe.Pointer(&(weightVec[0]));
	lr.model = Train(prob, lr.param)
	return nil
}

// Predict issues predictions from a trained LinearSVC.
func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {

// Only support 1 class Attribute
Expand All @@ -59,6 +191,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
// Generate return structure
ret := base.GeneratePredictionVector(X)
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)

// Retrieve numeric non-class Attributes
numericAttrs := base.NonClassFloatAttributes(X)
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
Expand All @@ -78,6 +211,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
return ret, nil
}

// String returns a human-readable version.
// NOTE(review): the name returned is "LogisticSVC", which does not match
// the type name LinearSVC — confirm whether this is intentional.
func (lr *LinearSVC) String() string {
	return "LogisticSVC"
}
Loading

0 comments on commit aeb12bd

Please sign in to comment.