Changeset 515 for trunk


Ignore:
Timestamp:
02/28/05 15:32:29 (19 years ago)
Author:
mar637
Message:
  • major rework on plotting.
  • added component selection and plotting
  • added wrapper function to set parameters
  • addedd formatting of parameter print out
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/python/asapfitter.py

    r302 r515  
    2525        self.mask = None
    2626        self.fitfunc = None
     27        self.fitfuncs = None
    2728        self.fitted = False
    2829        self.data = None
     30        self.components = 0
     31        self._fittedrow = 0
    2932        self._p = None
    3033        self._vb = True
     34        self._selection = None
    3135
    3236    def set_data(self, xdat, ydat, mask=None):
     
    8084            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
    8185        """
    82         #default poly order 0
    83        
    84 
     86        #default poly order 0       
     87        n=0
    8588        if kwargs.has_key('poly'):
    8689            self.fitfunc = 'poly'
    8790            n = kwargs.get('poly')
     91            self.components = [n]
    8892        elif kwargs.has_key('gauss'):
    8993            n = kwargs.get('gauss')
    9094            self.fitfunc = 'gauss'
    91        
     95            self.fitfuncs = [ 'gauss' for i in range(n) ]
     96            self.components = [ 3 for i in range(n) ]
     97        else:
     98            print "Invalid function type."
     99            return
    92100        self.fitter.setexpression(self.fitfunc,n)
    93101        return
    94102           
    95     def fit(self):
     103    def fit(self, row=0):
    96104        """
    97105        Execute the actual fitting process. All the state has to be set.
     
    99107            none
    100108        Example:
    101             s= scantable('myscan.asap')
     109            s = scantable('myscan.asap')
     110            s.set_cursor(thepol=1)        # select second pol
    102111            f = fitter()
    103112            f.set_scan(s)
    104113            f.set_function(poly=0)
    105             f.fit()
     114            f.fit(row=0)                  # fit first row
    106115        """
    107116        if ((self.x is None or self.y is None) and self.data is None) \
     
    111120        else:
    112121            if self.data is not None:
    113                 self.x = self.data._getabcissa()
    114                 self.y = self.data._getspectrum()
     122                self.x = self.data._getabcissa(row)
     123                self.y = self.data._getspectrum(row)
    115124                print "Fitting:"
    116125                vb = self.data._vb
    117126                self.data._vb = True
    118                 s = self.data.get_cursor()
     127                self.selection = self.data.get_cursor()
    119128                self.data._vb = vb
    120        
    121         self.fitter.setdata(self.x,self.y,self.mask)
     129        self.fitter.setdata(self.x, self.y, self.mask)
    122130        if self.fitfunc == 'gauss':
    123131            ps = self.fitter.getparameters()
     
    125133                self.fitter.estimate()
    126134        self.fitter.fit()
     135        self._fittedrow = row
    127136        self.fitted = True
    128137        return
    129138
    130     def set_parameters(self, params, fixed=None):
     139    def store_fit(self):
     140        if self.fitted and self.data is not None:
     141            pars = list(self.fitter.getparameters())
     142            fixed = list(self.fitter.getfixedparameters())
     143            self.data._addfit(self._fittedrow, pars, fixed,
     144                              self.fitfuncs, self.components)
     145
     146    def set_parameters(self, params, fixed=None, component=None):
     147        if self.fitfunc is None:
     148            print "Please specify a fitting function first."
     149            return
     150        if self.fitfunc == "gauss" and component is not None:
     151            if not self.fitted:
     152                from numarray import zeros
     153                pars = list(zeros(len(self.components)*3))
     154                fxd = list(zeros(len(pars)))
     155            else:
     156                pars = list(self.fitter.getparameters())             
     157                fxd = list(self.fitter.getfixedparameters())
     158            i = 3*component
     159            pars[i:i+3] = params
     160            fxd[i:i+3] = fixed
     161            params = pars
     162            fixed = fxd         
    131163        self.fitter.setparameters(params)
    132164        if fixed is not None:
    133165            self.fitter.setfixedparameters(fixed)
    134166        return
    135    
    136     def get_parameters(self):
     167
     168    def set_gauss_parameters(self, peak, centre, fhwm,
     169                             peakfixed=False, centerfixed=False,
     170                             fhwmfixed=False,
     171                             component=0):
     172        """
     173        Set the Parameters of a 'Gaussian' component, set with set_function.
     174        Parameters:
     175            component:           The number of the component (Default is the
     176                                 first component.
     177            peak, centre, fhwm:  The gaussian parameters
     178            peakfixed,
     179            centerfixed,
     180            fhwmfixed:           Optional parameters to indicate if
     181                                 the paramters should be held fixed during
     182                                 the fitting process. The default is to keep
     183                                 all parameters flexible.
     184        """
     185        if self.fitfunc != "gauss":
     186            print "Function only operates on Gaussian components."
     187            return
     188        if 0 <= component < len(self.components):
     189            self.set_parameters([peak, centre, fhwm],
     190                                [peakfixed, centerfixed, fhwmfixed],
     191                                component)
     192        else:
     193            print "Please select a valid  component."
     194            return
     195       
     196    def get_parameters(self, component=None):
    137197        """
    138198        Return the fit paramters.
     
    143203        pars = list(self.fitter.getparameters())
    144204        fixed = list(self.fitter.getfixedparameters())
     205        if component is not None:           
     206            if self.fitfunc == "gauss":
     207                i = 3*component
     208                cpars = pars[i:i+3]
     209                cfixed = fixed[i:i+3]
     210            else:
     211                cpars = pars
     212                cfixed = fixed               
     213        else:
     214            cpars = pars
     215            cfixed = fixed
     216        fpars = self._format_pars(cpars, cfixed)
    145217        if self._vb:
    146             print self._format_pars(pars)
    147         return pars,fixed
     218            print fpars
     219        return cpars, cfixed, fpars
    148220   
    149     def _format_pars(self, pars):
     221    def _format_pars(self, pars, fixed):
    150222        out = ''
    151223        if self.fitfunc == 'poly':
    152224            c = 0
    153             for i in pars:
    154                 out += '  p%d = %3.3f, ' % (c,i)
     225            for i in range(len(pars)):
     226                fix = ""
     227                if fixed[i]: fix = "(fixed)"
     228                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
    155229                c+=1
     230            out = out[:-1]  # remove trailing ','
    156231        elif self.fitfunc == 'gauss':
    157232            i = 0
    158233            c = 0
    159             unit = ''
     234            aunit = ''
     235            ounit = ''
    160236            if self.data:
    161                 unit = self.data.get_unit()
     237                aunit = self.data.get_unit()
     238                ounit = self.data.get_fluxunit()
    162239            while i < len(pars):
    163                 out += '  %d: peak = %3.3f , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],pars[i+1],unit,pars[i+2],unit)
     240                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
    164241                c+=1
    165242                i+=3
     
    168245    def get_estimate(self):
    169246        """
    170         Return the paramter estimates (for non-linear functions).
     247        Return the parameter estimates (for non-linear functions).
    171248        """
    172249        pars = self.fitter.getestimate()
     
    188265        Return chi^2.
    189266        """
    190        
    191267        if not self.fitted:
    192268            print "Not yet fitted."
     
    206282    def commit(self):
    207283        """
    208         Return a new scan where teh fits have been commited.
     284        Return a new scan where the fits have been commited.
    209285        """
    210286        if not self.fitted:
     
    216292        scan._setspectrum(self.fitter.getresidual())
    217293
    218     def plot(self, residual=False):
     294    def plot(self, residual=False, components=None, plotparms=False,
     295             plotrange=None):
    219296        """
    220297        Plot the last fit.
     
    232309            self._p = ASAPlot()
    233310        self._p.clear()
     311        self._p.set_panels()
     312        self._p.palette(1)
    234313        tlab = 'Spectrum'
    235         xlab = 'Abcissa'
     314        xlab = 'Abcissa'       
     315        m = ()
    236316        if self.data:
    237             tlab = self.data._getsourcename(0)
    238             xlab = self.data._getabcissalabel(0)
    239         ylab = r'Flux'
    240         m = self.data._getmask(0)
    241         self._p.set_line(colour='blue',label='Spectrum')
     317            tlab = self.data._getsourcename(self._fittedrow)
     318            xlab = self.data._getabcissalabel(self._fittedrow)
     319            m = self.data._getmask(self._fittedrow)
     320            ylab = r'Flux'
     321
     322        colours = ["grey60","grey80","red","orange","purple","yellow","magenta", "cyan"]
     323        self._p.palette(1,colours)
     324        self._p.set_line(label='Spectrum')
    242325        self._p.plot(self.x, self.y, m)
    243326        if residual:
    244             self._p.set_line(colour='green',label='Residual')
     327            self._p.palette(2)
     328            self._p.set_line(label='Residual')
    245329            self._p.plot(self.x, self.get_residual(), m)
    246         self._p.set_line(colour='red',label='Fit')
    247         self._p.plot(self.x, self.get_fit(), m)
    248        
     330        self._p.palette(3)
     331        if components is not None:
     332            cs = components
     333            if isinstance(components,int): cs = [components]
     334            self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
     335            n = len(self.components)
     336            self._p.palette(4)
     337            for c in cs:
     338                if 0 <= c < n:
     339                    lab = self.fitfuncs[c]+str(c)
     340                    self._p.set_line(label=lab)
     341                    self._p.plot(self.x, self.fitter.evaluate(c), m)
     342                elif c == -1:
     343                    self._p.palette(3)
     344                    self._p.set_line(label="Total Fit")
     345                    self._p.plot(self.x, self.get_fit(), m)                   
     346        else:
     347            self._p.palette(3)
     348            self._p.set_line(label='Fit')
     349            self._p.plot(self.x, self.get_fit(), m)
    249350        self._p.set_axes('xlabel',xlab)
    250351        self._p.set_axes('ylabel',ylab)
     
    252353        self._p.release()
    253354
    254 
    255355    def auto_fit(self, insitu=None):
    256356        """
    257         Return a scan where the function is applied to all rows for all Beams/IFs/Pols.
     357        Return a scan where the function is applied to all rows for
     358        all Beams/IFs/Pols.
    258359       
    259360        """
    260361        from asap import scantable
    261         if not isinstance(self.data,scantable) :
     362        if not isinstance(self.data, scantable) :
    262363            print "Only works with scantables"
    263364            return
Note: See TracChangeset for help on using the changeset viewer.