source: branches/newfiller/python/asapfitter.py@ 2316

Last change on this file since 2316 was 1798, checked in by Malte Marquarding, 14 years ago

merge -r1774:1797 from alma to newfiller

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 26.1 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log, print_log_dec
4from asap import _n_bools
5from asap import mask_and
6from asap import asaplog
7
class fitter:
    """
    The fitting class for ASAP.

    Typical use: give the data with set_scan (a scantable) or set_data
    (plain abcissa/ordinate vectors), choose a model with set_function,
    optionally seed it with set_parameters / set_gauss_parameters /
    set_lorentz_parameters, call fit(), then inspect the result with
    get_parameters, get_fit, get_residual, get_chi2 or plot.

    Errors go through _report_error: when rcParams['verbose'] is set they
    are logged via asaplog and the method returns None, otherwise an
    exception is raised.
    """

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        # The mask given to set_data/set_scan.  It is kept apart from
        # self.mask so the per-row flags merged in by fit()/auto_fit() do
        # not accumulate from one row to the next.
        self._basemask = None
        self.fitfunc = None
        self.fitfuncs = None
        self.fitted = False
        self.data = None
        self.components = 0
        self._fittedrow = 0
        self._p = None
        self._selection = None
        self.uselinear = False

    def _report_error(self, msg, exctype):
        """
        Report an error consistently for all methods of this class.

        In verbose mode (rcParams['verbose']) the message is pushed to
        asaplog and flushed at ERROR level, and None is returned, so callers
        can write 'return self._report_error(...)'.  Otherwise exctype(msg)
        is raised.
        """
        if not rcParams['verbose']:
            raise exctype(msg)
        asaplog.push(msg)
        print_log('ERROR')
        return None

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abcissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abcissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: every point is valid)
        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None', not '== None': a numpy mask compares element-wise and
        # would make the condition ambiguous.
        if mask is None:
            self.mask = _n_bools(len(xdat), True)
        else:
            self.mask = mask
        self._basemask = self.mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: every channel is valid)
        """
        if not thescan:
            return self._report_error("Please give a correct scan", TypeError)
        self.fitted = False
        self.data = thescan
        if mask is None:
            self.mask = _n_bools(self.data.nchan(), True)
        else:
            self.mask = mask
        self._basemask = self.mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given with nonlinear least squares fit
            lpoly:   use polynomial of the order given with linear least squares fit
            gauss:   fit the number of gaussian specified
            lorentz: fit the number of lorentzian specified
        If several keywords are given, the first in the order above wins.
        Example:
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
            fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial via linear method
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(lorentz=2) # will fit two lorentzians
        """
        if 'poly' in kwargs or 'lpoly' in kwargs:
            if 'poly' in kwargs:
                n = kwargs.get('poly')
                self.uselinear = False
            else:
                n = kwargs.get('lpoly')
                self.uselinear = True
            self.fitfunc = 'poly'
            # Previously left as None (or stale from an earlier gauss/lorentz
            # setup), which store_fit() and plot(components=...) index into.
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs or 'lorentz' in kwargs:
            if 'gauss' in kwargs:
                func = 'gauss'
            else:
                func = 'lorentz'
            n = kwargs.get(func)
            self.fitfunc = func
            self.fitfuncs = [func] * n
            # each profile component has 3 parameters: peak, centre, fwhm
            self.components = [3] * n
            self.uselinear = False
        else:
            return self._report_error("Invalid function type.", TypeError)

        self.fitter.setexpression(self.fitfunc, n)
        self.fitted = False
        return

    @print_log_dec
    def fit(self, row=0, estimate=False):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:         specify the row in the scantable
            estimate:    auto-compute an initial parameter set (default False)
                         This can be used to compute estimates even if fit was
                         called before.
        On failure (all parameters fixed, no convergence, or an engine
        RuntimeError) the error is logged in verbose mode and the fitter is
        left un-fitted; otherwise the RuntimeError propagates.
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            msg = "Fitter not yet initialised. Please set data & fit function"
            return self._report_error(msg, RuntimeError)

        # The engine's previous solution is replaced from here on, so the
        # fitter only counts as fitted again once this fit succeeds.
        self.fitted = False
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            basemask = self._basemask
            if basemask is None:
                basemask = self.mask
            # Combine the user mask with this row's flags only, so flags of
            # previously fitted rows do not leak into this one.
            self.mask = mask_and(basemask, self.data._getmask(row))
            asaplog.push("Fitting:")
            out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
                self.data.getscan(row), self.data.getbeam(row),
                self.data.getif(row), self.data.getpol(row),
                self.data.getcycle(row))
            asaplog.push(out, False)
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc in ('gauss', 'lorentz'):
            ps = self.fitter.getparameters()
            if len(ps) == 0 or estimate:
                self.fitter.estimate()
        try:
            fxdpar = list(self.fitter.getfixedparameters())
            if fxdpar and 0 not in fxdpar:
                raise RuntimeError("No point fitting, if all parameters are fixed.")
            if self.uselinear:
                converged = self.fitter.lfit()
            else:
                converged = self.fitter.fit()
            if not converged:
                raise RuntimeError("Fit didn't converge.")
        except RuntimeError:
            if not rcParams['verbose']:
                raise
            import sys
            print_log()
            asaplog.push(str(sys.exc_info()[1]))
            print_log('ERROR')
            # A failed fit must not be reported as fitted.
            return
        self._fittedrow = row
        self.fitted = True
        print_log()
        return

    def store_fit(self, filename=None):
        """
        Save the fit parameters.
        Parameters:
            filename:    if specified save as an ASCII file, if None (default)
                         store it in the scantable
        Does nothing unless a fit has been done on scantable data
        (set_scan).  Raises IOError if 'filename' already exists.
        """
        if not self.fitted or self.data is None:
            return
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        from asap.asapfit import asapfit
        fit = asapfit()
        fit.setparameters(pars)
        fit.setfixedparameters(fixed)
        fit.setfunctions(self.fitfuncs)
        fit.setcomponents(self.components)
        fit.setframeinfo(self.data._getcoordinfo())
        if filename is not None:
            import os
            filename = os.path.expandvars(os.path.expanduser(filename))
            if os.path.exists(filename):
                raise IOError("File '%s' exists." % filename)
            fit.save(filename)
        else:
            self.data._addfit(fit, self._fittedrow)

    @print_log_dec
    def set_parameters(self, *args, **kwargs):
        """
        Set the parameters to be fitted.
        Parameters may be passed as keywords or as a dict in the first
        positional argument (in which case other keywords are ignored,
        except 'component').
        Parameters:
            params:      a vector of parameters
            fixed:       a vector of which parameters are to be held fixed
                         (default is none)
            component:   in case of multiple gaussians/lorentzians, the index
                         of the component; may also be given as the second
                         positional argument
        """
        component = kwargs.get('component')
        if args and isinstance(args[0], dict):
            kwargs = args[0]
            component = kwargs.get('component', component)
        params = kwargs.get('params')
        fixed = kwargs.get('fixed')
        if len(args) == 2 and isinstance(args[1], int):
            component = args[1]
        if self.fitfunc is None:
            return self._report_error("Please specify a fitting function first.",
                                      RuntimeError)
        if self.fitfunc in ('gauss', 'lorentz') and component is not None:
            # Splice this component's 3 values into the full parameter list.
            if not self.fitted and sum(self.fitter.getparameters()) == 0:
                pars = _n_bools(len(self.components)*3, False)
                fxd = _n_bools(len(pars), False)
            else:
                pars = list(self.fitter.getparameters())
                fxd = list(self.fitter.getfixedparameters())
            i = 3*component
            if params is not None:
                pars[i:i+3] = params
            # previously 'fxd[i:i+3] = None' crashed when fixed was omitted
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        if params is not None:
            self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        print_log()
        return

    def _set_profile_parameters(self, func, label, peak, centre, fwhm,
                                peakfixed, centrefixed, fwhmfixed, component):
        """
        Shared implementation of set_gauss_parameters and
        set_lorentz_parameters: check that 'func' is the current fit
        function and 'component' exists, then set its three parameters.
        'label' is the human-readable function name used in the message.
        """
        if self.fitfunc != func:
            msg = "Function only operates on %s components." % label
            return self._report_error(msg, ValueError)
        if not 0 <= component < len(self.components):
            return self._report_error("Please select a valid component.",
                                      ValueError)
        d = {'params': [peak, centre, fwhm],
             'fixed': [peakfixed, centrefixed, fwhmfixed]}
        self.set_parameters(d, component)

    def set_gauss_parameters(self, peak, centre, fwhm,
                             peakfixed=0, centrefixed=0,
                             fwhmfixed=0,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fwhm:  The gaussian parameters
            peakfixed,
            centrefixed,
            fwhmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        return self._set_profile_parameters("gauss", "Gaussian",
                                            peak, centre, fwhm,
                                            peakfixed, centrefixed, fwhmfixed,
                                            component)

    def set_lorentz_parameters(self, peak, centre, fwhm,
                               peakfixed=0, centrefixed=0,
                               fwhmfixed=0,
                               component=0):
        """
        Set the Parameters of a 'Lorentzian' component, set with set_function.
        Parameters:
            peak, centre, fwhm:  The lorentzian parameters
            peakfixed,
            centrefixed,
            fwhmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        return self._set_profile_parameters("lorentz", "Lorentzian",
                                            peak, centre, fwhm,
                                            peakfixed, centrefixed, fwhmfixed,
                                            component)

    def _component_areas(self):
        """
        Return the list of per-component areas (peak * fwhm * shape factor)
        for a gauss/lorentz model, or None for any other fit function.
        """
        from math import log, pi, sqrt
        if self.fitfunc == "gauss":
            fac = sqrt(pi/log(16.0))
        elif self.fitfunc == "lorentz":
            fac = pi/2.0
        else:
            return None
        pars = list(self.fitter.getparameters())
        # parameters come in (peak, centre, fwhm) triples
        return [fac * pars[j] * pars[j+2]
                for j in range(0, 3*len(self.components), 3)]

    def get_area(self, component=None):
        """
        Return the area under the fitted gaussian/lorentzian component.
        Parameters:
              component:   the gaussian/lorentzian component selection,
                           default (None) is the sum of all components
        Note:
              This will only work for gaussian/lorentzian fits; None is
              returned otherwise, or if nothing has been fitted yet.
        """
        if not self.fitted:
            return None
        areas = self._component_areas()
        if areas is None:
            return None
        if component is not None:
            return areas[component]
        return sum(areas)

    def get_errors(self, component=None):
        """
        Return the errors in the parameters.
        Parameters:
            component:    get the errors for the specified component
                          only, default is all components
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        errs = list(self.fitter.geterrors())
        if component is not None and self.fitfunc in ("gauss", "lorentz"):
            i = 3*component
            if i < len(errs):
                return errs[i:i+3]
        return errs

    def get_parameters(self, component=None, errors=False):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
             errors:       include the errors in the formatted output
        Returns a dict with keys 'params', 'fixed', 'formatted', 'errors'.
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        errs = list(self.fitter.geterrors())
        profile = self.fitfunc in ("gauss", "lorentz")
        area = []
        if component is not None and profile:
            i = 3*component
            cpars = pars[i:i+3]
            cfixed = fixed[i:i+3]
            cerrs = errs[i:i+3]
            # one area entry per parameter, aligned with cpars
            area = [self._component_areas()[component]] * 3
        else:
            cpars = pars
            cfixed = fixed
            cerrs = errs
            if component is None and profile:
                for a in self._component_areas():
                    area += [a] * 3
        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
        if rcParams['verbose']:
            asaplog.push(fpars)
            print_log()
        return {'params': cpars, 'fixed': cfixed, 'formatted': fpars,
                'errors': cerrs}

    def _format_pars(self, pars, fixed, errors, area=None):
        """
        Format parameters as a human-readable string.
        Parameters:
            pars:    parameter values
            fixed:   per-parameter fixed flags (may be empty)
            errors:  per-parameter errors, or a false value to omit them
                     (only used for polynomials)
            area:    optional per-parameter areas for gauss/lorentz fits,
                     aligned with 'pars'; omitted from the output if empty
        """
        out = ''
        if self.fitfunc == 'poly':
            for c in range(len(pars)):
                fix = ""
                if len(fixed) and fixed[c]:
                    fix = "(fixed)"
                if errors:
                    out += '  p%d%s= %3.6f (%1.6f),' % (c, fix, pars[c],
                                                        errors[c])
                else:
                    out += '  p%d%s= %3.6f,' % (c, fix, pars[c])
            out = out[:-1]  # remove trailing ','
        elif self.fitfunc in ('gauss', 'lorentz'):
            aunit = ''
            ounit = ''
            if self.data is not None:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            for c, i in enumerate(range(0, len(pars), 3)):
                out += ('  %2d: peak = %3.3f %s , centre = %3.3f %s, '
                        'FWHM = %3.3f %s\n'
                        % (c, pars[i], ounit, pars[i+1], aunit,
                           pars[i+2], aunit))
                if area:
                    out += '      area = %3.3f %s %s\n' % (area[i], ounit,
                                                           aunit)
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        fixed = self.fitter.getfixedparameters()
        if rcParams['verbose']:
            asaplog.push(self._format_pars(pars, fixed, None))
            print_log()
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2 (also logged in verbose mode).
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        ch2 = self.fitter.getchi2()
        if rcParams['verbose']:
            asaplog.push('Chi^2 = %3.3f' % (ch2))
            print_log()
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        return self.fitter.getfit()

    @print_log_dec
    def commit(self):
        """
        Return a new scan where the fits have been commited (subtracted).
        Only the fitted row of the copy is replaced by the residual.
        """
        if not self.fitted:
            return self._report_error("Not yet fitted.", RuntimeError)
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._report_error("Not a scantable", TypeError)
        scan = self.data.copy()
        # Pass the fitted row explicitly (as auto_fit does); previously no
        # row was given, so the fitted row was not the one written.
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        print_log()
        return scan

    @print_log_dec
    def plot(self, residual=False, components=None, plotparms=False,
             filename=None):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
            filename:    passed to the plotter's save() when
                         rcParams['plotter.gui'] is off
        """
        if not self.fitted:
            return
        if not self._p or self._p.is_dead:
            if rcParams['plotter.gui']:
                from asap.asaplotgui import asaplotgui as asaplot
            else:
                from asap.asaplot import asaplot
            self._p = asaplot()
        self._p.hold()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        ylab = 'Ordinate'
        from numpy import ma, logical_not, logical_and, array
        m = self.mask
        if self.data is not None:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = logical_and(self.mask,
                            array(self.data._getmask(self._fittedrow),
                                  copy=False))
            ylab = self.data._get_ordinate_label()
        # numpy.ma hides points where its mask is True, so the inverse of
        # m shows only the channels that took part in the fit.
        invm = logical_not(m)

        colours = ["#777777", "#dddddd", "red", "orange", "purple", "green",
                   "magenta", "cyan"]
        label0 = 'Masked Region'
        label1 = 'Spectrum'
        if all(m):
            label0 = label1
        else:
            # draw the channels excluded from the fit separately
            y = ma.masked_array(self.y, mask=m)
            self._p.palette(1, colours)
            self._p.set_line(label=label1)
            self._p.plot(self.x, y)
        self._p.palette(0, colours)
        self._p.set_line(label=label0)
        y = ma.masked_array(self.y, mask=invm)
        self._p.plot(self.x, y)
        if residual:
            self._p.palette(7)
            self._p.set_line(label='Residual')
            y = ma.masked_array(self.get_residual(), mask=invm)
            self._p.plot(self.x, y)
        self._p.palette(2)
        # Previously only honoured when 'components' was given.
        if plotparms:
            self._p.text(0.15, 0.15,
                         str(self.get_parameters()['formatted']), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c] + str(c)
                    self._p.set_line(label=lab)
                    y = ma.masked_array(self.fitter.evaluate(c), mask=invm)
                    self._p.plot(self.x, y)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    y = ma.masked_array(self.fitter.getfit(), mask=invm)
                    self._p.plot(self.x, y)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            y = ma.masked_array(self.fitter.getfit(), mask=invm)
            self._p.plot(self.x, y)
        xlim = [min(self.x), max(self.x)]
        self._p.axes.set_xlim(xlim)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()
        if not rcParams['plotter.gui']:
            self._p.save(filename)
        print_log()

    @print_log_dec
    def auto_fit(self, insitu=None, plot=False):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols.
        Parameters:
            insitu:  modify the fitter's scantable in place instead of a
                     copy (default: rcParams['insitu'])
            plot:    plot each fit and ask interactively whether to accept it
        Each accepted row's spectrum is replaced by its fit residual.  The
        parameters of every row are stored in self.blpars (None for rows
        whose fit failed or was rejected).  Afterwards self.data is None.
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._report_error("Data is not a scantable", TypeError)
        if insitu is None:
            insitu = rcParams['insitu']
        if insitu:
            scan = self.data
        else:
            scan = self.data.copy()
        basemask = self._basemask
        if basemask is None:
            basemask = self.mask
        # Save parameters of baseline fits as a class attribute.
        # NOTICE: This does not reflect changes in scantable!
        self.blpars = []
        # fit() must use the per-row x/y/mask set below, not re-read data.
        self.data = None
        asaplog.push("Fitting:")
        for r in range(scan.nrow()):
            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
                scan.getscan(r), scan.getbeam(r), scan.getif(r),
                scan.getpol(r), scan.getcycle(r))
            asaplog.push(out, False)
            self.x = scan._getabcissa(r)
            self.y = scan._getspectrum(r)
            # user mask AND this row's flags only (no accumulation)
            self.mask = mask_and(basemask, scan._getmask(r))
            self.fit()
            if not self.fitted:
                # fit() already logged the failure (verbose mode only;
                # otherwise it raised); leave this row untouched.
                self.blpars.append(None)
                continue
            fpar = self.get_parameters()
            if plot:
                self.plot(residual=True)
                answer = raw_input("Accept fit ([y]/n): ")
                if answer.upper() == 'N':
                    self.blpars.append(None)
                    continue
            scan._setspectrum(self.fitter.getresidual(), r)
            self.blpars.append(fpar)
        if plot and self._p is not None:
            self._p.unmap()
            self._p = None
        print_log()
        return scan
Note: See TracBrowser for help on using the repository browser.