source: branches/Release2.0/python/asapfitter.py@ 1123

Last change on this file since 1123 was 1039, checked in by mar637, 19 years ago

minor fix to printing of areas by component. The sum of all areas was printed instead of the area of the individual components.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 17.7 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log
4
class fitter:
    """
    The fitting class for ASAP.

    Typical use: give the data with set_scan (a scantable) or set_data
    (plain abcissa/ordinate vectors), choose the model with set_function,
    optionally seed it with set_parameters/set_gauss_parameters, call fit()
    and then query, plot, store or commit the result.

    Errors follow the ASAP convention used throughout this class: when
    rcParams['verbose'] is true the message is printed and the method
    returns None, otherwise an exception is raised.
    """

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None
        self.fitfuncs = None
        self.fitted = False
        self.data = None
        # per-component parameter counts (gauss) or [order] (poly); a list
        # from the start so len(self.components) is always valid
        self.components = []
        self._fittedrow = 0
        self._p = None
        self._selection = None

    def _error(self, msg, exctype):
        """
        Report an error according to rcParams['verbose'].
        Prints 'msg' and returns None in verbose mode, otherwise raises
        exctype(msg). Callers use 'return self._error(...)' so that the
        verbose path returns None from the public method.
        """
        if rcParams['verbose']:
            print(msg)
            return None
        raise exctype(msg)

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abcissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abcissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points valid)
        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # detach any scantable given earlier; otherwise fit() would re-read
        # the spectrum from it and silently ignore these vectors
        self.data = None
        # 'is None' - '== None' compares element-wise on arrays
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels valid)
        """
        if not thescan:
            return self._error("Please give a correct scan", TypeError)
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        if 'poly' in kwargs:
            n = kwargs['poly']
            self.fitfunc = 'poly'
            # a polynomial is a single component; keep fitfuncs in step so
            # store_fit/plot never see a stale gauss list (or None)
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs['gauss']
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            # each gaussian has three parameters: peak, centre, FWHM
            self.components = [3] * n
        else:
            return self._error("Invalid function type.", TypeError)
        # a previous result no longer matches the new model description
        self.fitted = False
        self.fitter.setexpression(self.fitfunc, n)

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:        specify the row in the scantable
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        On failure in verbose mode the error is printed and the fitter is
        left in the 'not fitted' state.
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            msg = "Fitter not yet initialised. Please set data & fit function"
            return self._error(msg, RuntimeError)
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            from asap import asaplog
            asaplog.push("Fitting:")
            out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
                self.data.getscan(row), self.data.getbeam(row),
                self.data.getif(row), self.data.getpol(row),
                self.data.getcycle(row))
            asaplog.push(out)
        # invalidate the previous result until this fit has succeeded
        self.fitted = False
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            # no user-supplied starting values: let the fitter estimate them
            if len(self.fitter.getparameters()) == 0:
                self.fitter.estimate()
        try:
            self.fitter.fit()
        except RuntimeError:
            if not rcParams['verbose']:
                raise
            import sys
            print(sys.exc_info()[1])
            print_log()
            return
        self._fittedrow = row
        self.fitted = True
        print_log()

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Does nothing unless a fit has been done on a scantable.
        """
        if not self.fitted or self.data is None:
            return
        from asap.asapfit import asapfit
        fit = asapfit()
        fit.setparameters(list(self.fitter.getparameters()))
        fit.setfixedparameters(list(self.fitter.getfixedparameters()))
        fit.setfunctions(self.fitfuncs)
        fit.setcomponents(self.components)
        fit.setframeinfo(self.data._getcoordinfo())
        self.data._addfit(fit, self._fittedrow)

    def set_parameters(self, *args, **kwargs):
        """
        Set the parameters to be fitted.
        Parameters:
            params:    a vector of parameters
            fixed:     a vector of which parameters are to be held fixed
                       (default is none)
            component: in case of multiple gaussians, the index of the
                       component
        The values can be given as keywords, or as a dictionary
        {'params': ..., 'fixed': ..., 'component': ...}, optionally
        followed by the component index as a second positional argument,
        e.g. set_parameters({'params': [1.0, 0.0, 2.0]}, 1).
        """
        if len(args) and isinstance(args[0], dict):
            kwargs = args[0]
        params = kwargs.get("params")
        fixed = kwargs.get("fixed")
        component = kwargs.get("component")
        if len(args) == 2 and isinstance(args[1], int):
            component = args[1]
        if self.fitfunc is None:
            return self._error("Please specify a fitting function first.",
                               RuntimeError)
        if params is None:
            return self._error("Please give the parameters to set.",
                               TypeError)
        if self.fitfunc == "gauss" and component is not None:
            if not 0 <= component < len(self.components):
                return self._error("Please select a valid component.",
                                   ValueError)
            # splice this component's three values into the full vectors
            if not self.fitted and sum(self.fitter.getparameters()) == 0:
                npar = 3 * len(self.components)
                pars = [0.0] * npar
                fxd = [0] * npar
            else:
                pars = list(self.fitter.getparameters())
                fxd = list(self.fitter.getfixedparameters())
            i = 3 * component
            pars[i:i+3] = params
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        print_log()

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=0, centerfixed=0,
                             fhwmfixed=0,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            return self._error("Function only operates on Gaussian components.",
                               ValueError)
        if not 0 <= component < len(self.components):
            return self._error("Please select a valid component.", ValueError)
        d = {'params': [peak, centre, fhwm],
             'fixed': [peakfixed, centerfixed, fhwmfixed]}
        self.set_parameters(d, component)

    def get_area(self, component=None):
        """
        Return the area under the fitted gaussian component.
        Parameters:
              component:   the gaussian component selection,
                           default (None) is the sum of all components
        Note:
              This will only work for gaussian fits; None is returned
              otherwise, or when no fit has been done.
        """
        if not self.fitted or self.fitfunc != "gauss":
            return None
        from math import log, pi, sqrt
        # area of a gaussian = peak * FWHM * sqrt(pi / (4 ln 2))
        fac = sqrt(pi / log(16.0))
        pars = list(self.fitter.getparameters())
        areas = []
        for i in range(len(self.components)):
            peak = pars[3*i]
            fwhm = pars[3*i + 2]
            areas.append(fac * peak * fwhm)
        if component is not None:
            return areas[component]
        return sum(areas)

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
        Returns a dictionary with keys 'params', 'fixed' and 'formatted'
        (a printable summary, also printed in verbose mode).
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        # one area entry per parameter, so _format_pars can index it with
        # the parameter index of each component's peak
        area = []
        if component is not None and self.fitfunc == "gauss":
            i = 3 * component
            cpars = pars[i:i+3]
            cfixed = fixed[i:i+3]
            area = [self.get_area(component)] * 3
        else:
            cpars = pars
            cfixed = fixed
            if component is None and self.fitfunc == "gauss":
                for c in range(len(self.components)):
                    area += [self.get_area(c)] * 3
        fpars = self._format_pars(cpars, cfixed, area)
        if rcParams['verbose']:
            print(fpars)
        return {'params': cpars, 'fixed': cfixed, 'formatted': fpars}

    def _format_pars(self, pars, fixed, area):
        """
        Return a printable summary of the parameters 'pars'.
        'fixed' flags held parameters (used for polynomials). 'area' holds
        one entry per parameter (each gaussian's area repeated for its
        three parameters); pass None or an empty list to omit areas.
        """
        out = ''
        if self.fitfunc == 'poly':
            items = []
            for i in range(len(pars)):
                fix = ""
                if fixed[i]:
                    fix = "(fixed)"
                items.append(' p%d%s= %3.3f' % (i, fix, pars[i]))
            out = ','.join(items)
        elif self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            for i in range(0, len(pars), 3):
                c = i // 3
                out += ' %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c, pars[i], ounit, pars[i+1], aunit, pars[i+2], aunit)
                if area:
                    out += ' area = %3.3f %s %s\n' % (area[i], ounit, aunit)
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        fixed = self.fitter.getfixedparameters()
        if rcParams['verbose']:
            print(self._format_pars(pars, fixed, None))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        ch2 = self.fitter.getchi2()
        if rcParams['verbose']:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fits have been committed (subtracted).
        Only the fitted row is replaced by its residual.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._error("Not a scantable", TypeError)
        scan = self.data.copy()
        # write the residual back to the row that was actually fitted
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        print_log()
        return scan

    def plot(self, residual=False, components=None, plotparms=False,
             filename=None):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
            filename:    output file used when not plotting to the GUI
        """
        if not self.fitted:
            return
        if not self._p or self._p.is_dead:
            if rcParams['plotter.gui']:
                from asap.asaplotgui import asaplotgui as asaplot
            else:
                from asap.asaplot import asaplot
            self._p = asaplot()
        self._p.hold()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        ylab = 'Ordinate'
        m = None
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
            ylab = self.data._get_ordinate_label()

        colours = ["#777777", "#bbbbbb", "red", "orange", "purple", "green",
                   "magenta", "cyan"]
        self._p.palette(0, colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(1)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(2)
        # independent of 'components' so the option works on its own
        if plotparms:
            self._p.text(0.15, 0.15,
                         str(self.get_parameters()['formatted']), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c] + str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        xlim = [min(self.x), max(self.x)]
        self._p.axes.set_xlim(xlim)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()
        if (not rcParams['plotter.gui']):
            self._p.save(filename)
        print_log()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols.
        Each row's spectrum is replaced by its fit residual. Rows whose
        fit fails (verbose mode) are left unchanged.
        Parameters:
            insitu:   modify the fitter's scantable in place instead of a
                      copy (default: rcParams['insitu'])
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._error("Data is not a scantable", TypeError)
        if insitu is None:
            insitu = rcParams['insitu']
        if insitu:
            scan = self.data
        else:
            scan = self.data.copy()
        from asap import asaplog
        asaplog.push("Fitting:")
        origdata = self.data
        try:
            # fit() re-reads x/y from self.data when it is set; detach it
            # so the per-row vectors assigned below are used instead
            self.data = None
            for r in range(scan.nrow()):
                out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
                    scan.getscan(r), scan.getbeam(r), scan.getif(r),
                    scan.getpol(r), scan.getcycle(r))
                asaplog.push(out, False)
                self.x = scan._getabcissa(r)
                self.y = scan._getspectrum(r)
                # pass the row so _fittedrow matches the data afterwards
                self.fit(row=r)
                if not self.fitted:
                    continue
                if rcParams['verbose']:
                    # prints the parameters of this row
                    self.get_parameters()
                scan._setspectrum(self.fitter.getresidual(), r)
        finally:
            # restore the scantable so store_fit/commit/plot keep working
            self.data = origdata
        print_log()
        return scan
526
Note: See TracBrowser for help on using the repository browser.