source: trunk/python/asapfitter.py@ 659

Last change on this file since 659 was 652, checked in by mar637, 19 years ago

Removed color loading as mpl now supports named colors. Some minor corrections on pol label handling. Also added orientation option for ps output.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.0 KB
Line 
import sys

import _asap
from asap import rcParams
3
class fitter:
    """
    The fitting class for ASAP.

    Wraps the compiled '_asap.fitter' object and keeps the python side
    state (data vectors or scantable, mask, fit function, plotter).
    Messages are written with single-argument 'print(...)' so that the
    output is the same whether print is a statement or a function.
    """
    def _verbose(self, *args):
        """
        Get or set stdout output.
        Parameters:
            args:    without an argument the current flag is returned,
                     a single bool sets the flag. Anything else is
                     ignored.
        """
        # The emptiness test has to come first: indexing args[0] on an
        # empty tuple raised an IndexError in the getter form.
        if len(args) == 0:
            return self._vb
        if type(args[0]) is bool:
            self._vb = args[0]
        return

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None
        self.fitfuncs = None
        self.fitted = False
        self.data = None
        self.components = 0
        self._fittedrow = 0
        self._p = None
        self._vb = True
        self._selection = None

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abscissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abscissa values
            ydat:    the ordinate values
            mask:    an optional mask (default is all points valid)

        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None' rather than '== None': an array mask would be
        # compared element by element.
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable (default is
                         all channels valid)
        """
        if not thescan:
            print("Please give a correct scan")
            # Leave the current state untouched; continuing would fail
            # on thescan.nchan() below.
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        # default poly order 0
        n = 0
        if 'poly' in kwargs:
            n = kwargs['poly']
            self.fitfunc = 'poly'
            # Keep fitfuncs in step with fitfunc; it used to stay None
            # (or hold stale 'gauss' entries) for polynomials.
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs['gauss']
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            self.components = [3] * n
        else:
            print("Invalid function type.")
            return
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:    specify the row in the scantable
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        nodata = (self.x is None or self.y is None) and self.data is None
        if nodata or self.fitfunc is None:
            print("Fitter not yet initialised. Please set data & fit function")
            return
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            print("Fitting:")
            # get_cursor() is called with the scantable's _vb flag forced
            # on; the previous flag is put back even if the call fails.
            vb = self.data._vb
            self.data._vb = True
            try:
                cursor = self.data.get_cursor()
            finally:
                self.data._vb = vb
            # '_selection' is the attribute declared in __init__;
            # 'selection' is kept because this method always set it.
            self._selection = cursor
            self.selection = cursor
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            # Only estimate when no start parameters have been supplied.
            if len(self.fitter.getparameters()) == 0:
                self.fitter.estimate()
        try:
            self.fitter.fit()
        except RuntimeError:
            # sys.exc_info() instead of 'except RuntimeError, msg' so the
            # handler parses on every python version.
            print(sys.exc_info()[1])
            # NOTE(review): the fit is still flagged as done below even
            # though it failed; kept as is since callers may rely on it.
        self._fittedrow = row
        self.fitted = True
        return

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Nothing is done unless a fit was made and a scantable is set.
        """
        if self.fitted and self.data is not None:
            pars = list(self.fitter.getparameters())
            fixed = list(self.fitter.getfixedparameters())
            self.data._addfit(self._fittedrow, pars, fixed,
                              self.fitfuncs, self.components)

    def set_parameters(self, params, fixed=None, component=None):
        """
        Set the parameters to be fitted.
        Parameters:
              params:    a vector of parameters
              fixed:     a vector of which parameters are to be held fixed
                         (default is none)
              component: in case of multiple gaussians, the index of the
                         component
        """
        if self.fitfunc is None:
            print("Please specify a fitting function first.")
            return
        if self.fitfunc == "gauss" and component is not None:
            if not self.fitted:
                # Nothing fitted yet: start from all-zero vectors with
                # three slots per gaussian component.
                pars = [0] * (len(self.components) * 3)
                fxd = [0] * len(pars)
            else:
                pars = list(self.fitter.getparameters())
                fxd = list(self.fitter.getfixedparameters())
            i = 3 * component
            pars[i:i+3] = params
            # Without a 'fixed' vector the flags of this component are
            # left alone; assigning None to the slice raised a TypeError.
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=False, centerfixed=False,
                             fhwmfixed=False,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            print("Function only operates on Gaussian components.")
            return
        if 0 <= component < len(self.components):
            self.set_parameters([peak, centre, fhwm],
                                [peakfixed, centerfixed, fhwmfixed],
                                component)
        else:
            print("Please select a valid component.")
            return

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
        Returns a tuple (parameters, fixed flags, formatted string).
        """
        if not self.fitted:
            # Only a notice; the current values are still returned.
            print("Not yet fitted.")
        cpars = list(self.fitter.getparameters())
        cfixed = list(self.fitter.getfixedparameters())
        # Only gaussians can be addressed per component (3 values each);
        # for any other function the full vectors are returned.
        if component is not None and self.fitfunc == "gauss":
            i = 3 * component
            cpars = cpars[i:i+3]
            cfixed = cfixed[i:i+3]
        fpars = self._format_pars(cpars, cfixed)
        if self._vb:
            print(fpars)
        return cpars, cfixed, fpars

    def _format_pars(self, pars, fixed=None):
        """
        Return the parameters formatted as a string.
        Parameters:
            pars:     the parameter values
            fixed:    the matching 'fixed' flags, only used to label
                      polynomial coefficients (default is none fixed)
        An empty string is returned if the fit function is neither
        'poly' nor 'gauss'.
        """
        if fixed is None:
            fixed = [False] * len(pars)
        out = ''
        if self.fitfunc == 'poly':
            terms = []
            for i in range(len(pars)):
                fix = ""
                if fixed[i]:
                    fix = "(fixed)"
                terms.append(' p%d%s= %3.3f' % (i, fix, pars[i]))
            out = ','.join(terms)
        elif self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            # Three consecutive values (peak, centre, FWHM) per component.
            c = 0
            for i in range(0, len(pars), 3):
                out += ' %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
                c += 1
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        if self._vb:
            # Estimates carry no 'fixed' flags, so none are passed.
            print(self._format_pars(pars))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            print("Not yet fitted.")
        ch2 = self.fitter.getchi2()
        if self._vb:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fits have been committed (subtracted).
        Nothing is returned if no fit was made or if the data is not a
        scantable.
        """
        if not self.fitted:
            print("Not yet fitted.")
            return
        # Imported here, as in auto_fit; the name is not available at
        # module level.
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        scan = self.data.copy()
        scan._setspectrum(self.fitter.getresidual())
        return scan

    def plot(self, residual=False, components=None, plotparms=False):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot (only used together with 'components')
        """
        if not self.fitted:
            return
        # (Re)create the plotter if there is none or it has been closed.
        if not self._p or self._p.is_dead:
            from asap.asaplot import ASAPlot
            self._p = ASAPlot()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        # Generic labels, used when fitting plain vectors (set_data).
        # 'ylab' had no default, which raised a NameError in that case.
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        ylab = 'Ordinate'
        m = ()
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
            ylab = self.data._get_ordinate_label()

        colours = ["grey60","grey80","red","orange","purple","green","magenta", "cyan"]
        self._p.palette(0,colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(1)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(2)
        if components is not None:
            cs = components
            if isinstance(components,int): cs = [components]
            if plotparms:
                self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c]+str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        self._p.set_axes('xlabel',xlab)
        self._p.set_axes('ylabel',ylab)
        self._p.set_axes('title',tlab)
        self._p.release()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols.
        Parameters:
            insitu:    if False a new scantable is returned, otherwise
                       the fitter's scantable is modified and nothing
                       is returned. The default is taken from
                       rcParams['insitu'].
        """
        from asap import scantable
        if not isinstance(self.data, scantable) :
            print("Only works with scantables")
            return
        if insitu is None: insitu = rcParams['insitu']
        thedata = self.data
        if insitu:
            scan = thedata
        else:
            scan = thedata.copy()
        vb = scan._vb
        scan._vb = False
        sel = scan.get_cursor()
        rows = range(scan.nrow())
        try:
            # With no scantable set, fit() works on self.x/self.y, which
            # are filled per row from 'scan' below.
            self.data = None
            for i in range(scan.nbeam()):
                scan.setbeam(i)
                for j in range(scan.nif()):
                    scan.setif(j)
                    for k in range(scan.npol()):
                        scan.setpol(k)
                        if self._vb:
                            print("Fitting:")
                            print('Beam[%d], IF[%d], Pol[%d]' % (i,j,k))
                        for iRow in rows:
                            self.x = scan._getabcissa(iRow)
                            self.y = scan._getspectrum(iRow)
                            self.fit()
                            # fit() records row 0; remember the real row.
                            self._fittedrow = iRow
                            self.get_parameters()
                            scan._setspectrum(self.fitter.getresidual(),iRow)
        finally:
            # The fitter used to be left without its scantable, so a
            # second auto_fit() call failed. Cursor and verbosity of the
            # scan are restored as well, even after an error.
            self.data = thedata
            scan.set_cursor(sel[0],sel[1],sel[2])
            scan._vb = vb
        if not insitu:
            return scan
Note: See TracBrowser for help on using the repository browser.