# myTMVA0.py
# Example code for TMVA in pyROOT
# By Andrea Holzner
# http://aholzner.wordpress.com/2011/08/27/a-tmva-example-in-pyroot/
import ROOT
# Create a TNtuple with one row per toy event: columns x, y and a 'signal'
# flag that is 1 for signal events and 0 for background events.
ntuple = ROOT.TNtuple("ntuple", "ntuple", "x:y:signal")

# Generate 'signal' and 'background' distributions. Each iteration throws one
# event of each class, so both classes end up with 10000 events.
for _ in range(10000):
    # throw a signal event centered at (1,1): x and y drawn from Gaus(1, 1)
    ntuple.Fill(ROOT.gRandom.Gaus(1, 1),   # x
                ROOT.gRandom.Gaus(1, 1),   # y
                1)                         # signal
    # throw a background event centered at (-1,-1): x and y from Gaus(-1, 1)
    ntuple.Fill(ROOT.gRandom.Gaus(-1, 1),  # x
                ROOT.gRandom.Gaus(-1, 1),  # y
                0)                         # background
# Holds references to ROOT graphics objects so that they are not
# garbage-collected (and thereby deleted) while still displayed.
gcSaver = []

# Create a new TCanvas and keep a direct reference to it. Saving through this
# reference avoids relying on `ROOT.c1`, i.e. on the canvas happening to get
# ROOT's default canvas name.
canvas = ROOT.TCanvas()
gcSaver.append(canvas)

# draw an empty 2D histogram (one bin over [-5,5] x [-5,5]) just for the axes
histo = ROOT.TH2F("histo", "", 1, -5, 5, 1, -5, 5)
histo.Draw()

# overlay the signal events in red
ntuple.SetMarkerColor(ROOT.kRed)
ntuple.Draw("y:x", "signal > 0.5", "same")

# overlay the background events in blue
ntuple.SetMarkerColor(ROOT.kBlue)
ntuple.Draw("y:x", "signal <= 0.5", "same")

canvas.SaveAs("scatter.png")
# Initialise TMVA and create the Factory that books, trains, tests and
# evaluates the classifiers. Its evaluation results go to `fout`.
# NOTE(review): this script uses the Factory-level AddVariable /
# AddSignalTree / BookMethod(type, ...) API. Newer ROOT releases moved these
# onto TMVA::DataLoader, so confirm the ROOT version this runs against.
ROOT.TMVA.Tools.Instance()
fout = ROOT.TFile("TMVA.root", "RECREATE")
factory = ROOT.TMVA.Factory("TMVAClassification", fout,
                            ":".join([
                                "!V",
                                "!Silent",
                                "Color",
                                "DrawProgressBar",
                                "AnalysisType=Classification",
                            ]))

# input variables of the classifier, both declared as floats ("F")
factory.AddVariable("x", "F")
factory.AddVariable("y", "F")

# Signal and background come from the same ntuple; the cuts below select
# which rows belong to which class.
factory.AddSignalTree(ntuple)
factory.AddBackgroundTree(ntuple)

# cuts defining the signal and background sample
sigCut = ROOT.TCut("signal > 0.5")
bgCut = ROOT.TCut("signal <= 0.5")

# nTrain_*=0 leaves the training-sample size to TMVA's default; the events
# are assigned to training/testing at random (SplitMode=Random).
factory.PrepareTrainingAndTestTree(sigCut,  # signal events
                                   bgCut,   # background events
                                   ":".join([
                                       "nTrain_Signal=0",
                                       "nTrain_Background=0",
                                       "SplitMode=Random",
                                       "NormMode=NumEvents",
                                       "!V",
                                   ]))

# Book a boosted decision tree (AdaBoost, Gini index, depth-3 trees).
# NOTE(review): nEventsMin is an older BDT option (newer ROOT uses
# MinNodeSize); confirm it is accepted by the ROOT version in use.
method = factory.BookMethod(ROOT.TMVA.Types.kBDT, "BDT",
                            ":".join([
                                "!H",
                                "!V",
                                "NTrees=850",
                                "nEventsMin=150",
                                "MaxDepth=3",
                                "BoostType=AdaBoost",
                                "AdaBoostBeta=0.5",
                                "SeparationType=GiniIndex",
                                "nCuts=20",
                                "PruneMethod=NoPruning",
                            ]))

factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file so that the evaluation results are written to disk.
fout.Close()