1 | from inspect import getargspec |
---|
2 | from ROOT import TLorentzVector,TVector3,TChain,TClass,TDatabasePDG |
---|
3 | from datetime import datetime |
---|
4 | from collections import Iterable |
---|
5 | from types import StringTypes |
---|
6 | from os import path |
---|
7 | |
---|
class AnalysisEvent(TChain):
    """A TChain over the "Delphes" tree complemented with analysis facilities.

    The class provides the following additional functionalities:
     1. instrumentation for event weight
        A set of weight classes can be defined, and the event weight
        is computed and cached using those.
     2. list of event products used in the analysis
        It makes the iteration faster by only enabling required branches.
     3. a list of "producers" of analysis high-level quantities
        It allows to run "analysis on demand", by automatically running
        the defined producers to fill the cache, and later use that one.
     4. a volatile dictionary
        It allows to use the event as a heterogeneous container for
        any analysis product. The volatile content is cleared each time
        another entry is loaded (iteration, to(), or indexing).
    """

    def __init__(self, inputFiles='', maxEvents=0):
        """Build the chain and set up the analysis caches.

        Arguments:
         * inputFiles: a single file name, or an iterable of file names.
           Names that are not existing files are skipped with a warning.
         * maxEvents: maximum number of entries yielded by an iteration;
           0 (or a negative value) means no limit.
        """
        # initialization of base functionalities
        TChain.__init__(self, "Delphes", "Delphes")
        # normalise the input to a sequence of candidate file names
        if isinstance(inputFiles, StringTypes):
            candidates = [inputFiles]
        elif isinstance(inputFiles, Iterable):
            candidates = inputFiles
        else:
            candidates = []
            print("Warning: invalid inputFiles")
        for thefile in candidates:
            if path.isfile(thefile):
                self.AddFile(thefile)
            else:
                print("Warning:  %s  do not exist." % (thefile,))
        self.BuildIndex("Event[0].Number")
        # every branch is disabled until a collection requests it
        self.SetBranchStatus("*", 0)
        self._eventCounts = 0
        self._maxEvents = maxEvents
        # additional features:
        # 1. instrumentation for event weight
        self._weightCache = {}
        self._weightEngines = {}
        # 2. a list of event products used in the analysis
        self._collections = {}
        # branch name -> True when enabled through addCollection.
        # NOTE(review): the "or []" guards against a null branch list, which
        # an empty chain is expected to return - confirm against ROOT.
        self._branches = dict((b.GetName(), False)
                              for b in (self.GetListOfBranches() or []))
        # 3. a list of "producers" of analysis high-level quantities
        self._producers = {}
        # 4. volatile dictionary. User can add any quantity to the event and
        #    it will be properly erased in the iteration step.
        #    It is written straight into __dict__ and must stay last:
        #    __setattr__ uses its presence to detect the end of __init__.
        self.__dict__["vardict"] = {}

    def addWeight(self, name, weightClass):
        """Declare a new class (engine) to compute the weights.

        weightClass must have a weight() method returning a float.
        Raises KeyError if an engine is already registered under that name.
        """
        if name in self._weightEngines:
            raise KeyError("%s weight engine is already declared" % name)
        self._weightEngines[name] = weightClass
        # cached products may no longer match the set of engines
        self._weightCache.clear()

    def delWeight(self, name):
        """Remove one weight engine from the internal list.

        Raises KeyError if no engine is registered under that name.
        """
        del self._weightEngines[name]
        self._weightCache.clear()

    def weight(self, weightList=None, **kwargs):
        """Return the event weight.

        Arguments:
         * weightList is the list of engines to use, as a list of strings.
           Default: all defined engines.
         * the other named arguments are forwarded to the engines that
           declare an argument of the same name.
        The output is the product of the selected individual weights.
        Both the individual weights and the product are cached.
        """
        if weightList is None:
            weightList = list(self._weightEngines.keys())
        kwargs["weightList"] = weightList
        # compute the weight or use the cached value
        myhash = self._dicthash(kwargs)
        if myhash not in self._weightCache:
            w = 1.
            for weightElement in weightList:
                engine = self._weightEngines[weightElement]
                engineArgs = getargspec(engine.weight).args
                subargs = dict((k, v) for k, v in kwargs.items()
                               if k in engineArgs)
                elementKey = "weightElement:%s # %s" % (
                    weightElement, self._dicthash(subargs))
                # only run the engine when its result is not cached yet
                if elementKey not in self._weightCache:
                    self._weightCache[elementKey] = engine.weight(self, **subargs)
                w *= self._weightCache[elementKey]
            self._weightCache[myhash] = w
        return self._weightCache[myhash]

    def addCollection(self, name, inputTag):
        """Register an event collection as used by the analysis.

        Example: addCollection("myjets","jets")
        The branches matching inputTag are enabled.
        Note that the direct access to the branch is still possible but unsafe.
        Raises KeyError if name is already a collection or a producer, and
        AttributeError if name is already an attribute of the event or if
        inputTag is not a known branch.
        """
        if name in self._collections:
            raise KeyError("%r collection is already declared" % name)
        if name in self._producers:
            raise KeyError("%r is already declared as a producer" % name)
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        if inputTag not in self._branches:
            raise AttributeError("%r object has no branch %r" % (type(self).__name__, inputTag))
        self._collections[name] = inputTag
        self.SetBranchStatus(inputTag + "*", 1)
        self._branches[inputTag] = True

    def removeCollection(self, name):
        """Forget about the named event collection.

        This method will delete both the product from the cache (if any)
        and the definition, and disables the corresponding branches.
        To simply clear the cache, use "del event.name" instead.
        Raises KeyError if the collection is not declared.
        """
        inputTag = self._collections[name]
        self.SetBranchStatus(inputTag + "*", 0)
        self._branches[inputTag] = False
        del self._collections[name]
        if name in self.vardict:
            delattr(self, name)

    def getCollection(self, name):
        """Retrieve the event product or return the cached collection.

        Note that the preferred way to get the collection is instead to
        access the "event.name" attribute.
        Raises AttributeError if the collection is not declared.
        """
        if name not in self._collections:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name not in self.vardict:
            self.vardict[name] = TChain.__getattr__(self, self._collections[name])
        return self.vardict[name]

    def addProducer(self, name, producer, **kwargs):
        """Register a producer to create new high-level analysis objects.

        The producer is called as producer(event, **kwargs) the first time
        "event.name" is accessed for a given entry.
        Raises KeyError if name is already a producer or a collection, and
        AttributeError if name is already an attribute of the event.
        """
        # sanity checks
        if name in self._producers:
            raise KeyError("%r producer is already declared" % name)
        if name in self._collections:
            raise KeyError("%r is already declared as a collection" % name)
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        # store
        self._producers[name] = (producer, kwargs)

    def removeProducer(self, name):
        """Forget about the producer.

        This method will delete both the product from the cache (if any)
        and the producer.
        To simply clear the cache, use "del event.name" instead.
        Raises KeyError if the producer is not declared.
        """
        del self._producers[name]
        if name in self.vardict:
            delattr(self, name)

    def event(self):
        """Event number, or 0 when the "Event" branch is not enabled."""
        if self._branches.get("Event", False):
            return self.Event.At(0).Number
        return 0

    def to(self, event):
        """Jump to the entry indexed by the given event number."""
        self.GetEntryWithIndex(event)
        # products cached for the previous entry are no longer valid
        self.vardict.clear()
        self._weightCache.clear()

    def __getitem__(self, index):
        """Jump to the entry indexed by the given event number; return self."""
        self.GetEntryWithIndex(index)
        # products cached for the previous entry are no longer valid
        self.vardict.clear()
        self._weightCache.clear()
        return self

    def __iter__(self):
        """Iterate over the entries, yielding the event itself each time.

        The volatile dictionary and the weight cache are cleared for every
        entry. The iteration stops after maxEvents entries when that limit
        is positive.
        """
        self._eventCounts = 0
        # NOTE(review): the loop ends as soon as GetEntry returns 0.
        while self.GetEntry(self._eventCounts):
            self.vardict.clear()
            self._weightCache.clear()
            yield self
            self._eventCounts += 1
            if self._maxEvents > 0 and self._eventCounts >= self._maxEvents:
                break

    def __getattr__(self, attr):
        """Overloaded getter to handle properly:
         - volatile analysis objects
         - event collections
         - data producers
        Anything else is delegated to TChain.
        """
        vardict = self.__dict__.get("vardict")
        if vardict is None:
            # __init__ has not completed: only the base-class lookup applies
            return TChain.__getattr__(self, attr)
        if attr in vardict:
            return vardict[attr]
        if attr in self._collections:
            return vardict.setdefault(attr, TChain.__getattr__(self, self._collections[attr]))
        if attr in self._producers:
            producer, producerArgs = self._producers[attr]
            return vardict.setdefault(attr, producer(self, **producerArgs))
        return TChain.__getattr__(self, attr)

    def __setattr__(self, name, value):
        """Overloaded setter that puts any new attribute in the volatile dict.

        Existing instance attributes, names starting with an underscore and
        anything set before the volatile dict exists go to the instance
        __dict__. Collection and producer names are read-only
        (AttributeError).
        """
        if name in self.__dict__ or "vardict" not in self.__dict__ or name.startswith('_'):
            self.__dict__[name] = value
        else:
            if name in self._collections or name in self._producers:
                raise AttributeError("%r object %r attribute is read-only (event collection)" % (type(self).__name__, name))
            self.vardict[name] = value

    def __delattr__(self, name):
        """Overloaded del method to handle the volatile internal dictionary.

        Raises AttributeError for "vardict" itself and for unknown names.
        """
        if name == "vardict":
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name in self.__dict__:
            del self.__dict__[name]
        elif name in self.vardict:
            del self.vardict[name]
        else:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))

    def _dicthash(self, dict):
        """Return a "key=value;key=value" string identifying the dictionary.

        Items are ordered by key so that equal dictionaries always give the
        same string, whatever their internal ordering.
        """
        items = sorted(dict.items(), key=lambda item: str(item[0]))
        return ';'.join(['='.join((str(k), str(v))) for k, v in items])

    def __str__(self):
        """Event text dump."""
        separator = "-----------------------------------------------------------------\n"

        def dictjoin(d, j=' => ', s='\n'):
            return s.join([j.join((str(k), str(v))) for k, v in d.items()])

        mystring = "=================================================================\n"
        # general information
        if self._branches.get("Event", False):
            mystring += str(self.Event.At(0))
        else:
            mystring += "Event %d\n" % self.GetReadEvent()
        mystring += separator
        # weights
        if len(self._weightCache) == 0:
            mystring += "No weight computed so far. Default weight is %f.\n" % self.weight()
        else:
            mystring += "Weights:\n"
            mystring += dictjoin(self._weightCache)
        mystring += "\n" + separator
        # list the collections
        mystring += "Collections:\n"
        for colname in self._collections:
            collection = self.getCollection(colname)
            if collection.GetEntries() > 0:
                # HepMCEvent collections are deliberately not dumped
                if collection.At(0).IsA() == TClass.GetClass("HepMCEvent"):
                    continue
                mystring += "*** %s has %d element(s)\n" % (colname, collection.GetEntries())
                mystring += "".join(map(str, collection))
        mystring += "\n" + separator
        # list the registered producers
        mystring += "Producers:\n"
        mystring += dictjoin(self._producers)
        mystring += "\n" + separator
        # list the content of vardict, excluding collections
        mystring += "Content of the cache:\n"
        for k, v in list(self.vardict.items()):
            if k in self._collections:
                continue
            if isinstance(v, Iterable) and not isinstance(v, StringTypes):
                try:
                    lines = ["%s => vector of %d objects(s)\n" % (k, len(v))]
                    for it, vec in enumerate(v):
                        lines.append("%s[%d] = %s\n" % (k, it, str(vec)))
                except Exception:
                    # no usable len() or iteration: fall back to str()
                    mystring += "%s => %s\n" % (k, str(v))
                else:
                    mystring += "".join(lines)
            else:
                mystring += "%s => %s\n" % (k, str(v))
        return mystring

    def decayTree(self, genparticles):
        """Return the concatenated decay printout of the parentless particles.

        A particle is considered parentless when both its M1 and M2 are -1;
        printDecay(db, genparticles) is called on each of them.
        """
        db = TDatabasePDG()
        return "".join(part.printDecay(db, genparticles)
                       for part in genparticles
                       if part.M1 == -1 and part.M2 == -1)