source: trunk/python/asaplotbase.py @ 1023

Last change on this file since 1023 was 1023, checked in by mar637, 18 years ago

The previous histogram plot was mutually exclusive with linestyle, so I am using asaplotbase.hist instead.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.3 KB
RevLine 
[705]1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
[710]14from asap import rcParams as asaprcParams
[705]15
[1019]16from matplotlib.ticker import ScalarFormatter
17from matplotlib.ticker import NullLocator
18
class MyFormatter(ScalarFormatter):
    """
    Scalar tick formatter that blanks a single tick label.

    asaplotbase.set_panels() installs it on ganged panels that are not in
    the rightmost column, presumably so that the blanked label does not
    collide with the first label of the neighbouring panel -- confirm.
    """

    def __call__(self, x, pos=None):
        # NOTE(review): the blanked tick is the one at index len(locs)-2,
        # not the final tick -- confirm this offset is intended.
        if pos == len(self.locs) - 2:
            return ''
        return ScalarFormatter.__call__(self, x, pos)
26
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    The figure is divided into panels (see set_panels()).  One panel at a
    time is current (see subplot()) and receives new lines.  Successive
    lines cycle through a colour map and, when that map has a single
    colour, through a list of dash styles.

    self.canvas is set to None here.  self.events and self.figmgr are read
    but never created in this class; presumably a display-specific subclass
    supplies all three -- confirm against the subclasses.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        title is the figure title and size the figure size passed to
        matplotlib's Figure.  buffering starts in buffered mode (see
        hold() and release()).
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence.  A non-empty ASAP rc parameter
        # 'plotter.colours' overrides it.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named dash sequences (on/off lengths) usable as line styles.
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        # Unknown names fall back to a solid line, as palette() does.
        # The old '-' fallback was not a valid dash sequence for set_dashes().
        self.linestyles = [self._dashes(ls) for ls in styles]

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def _dashes(self, name):
        """Return the dash sequence for a named line style (solid if unknown)."""
        return self.lsalias.get(name, self.lsalias['line'])

    def _normalize_kwargs(self, kwargs):
        """Return kwargs with lower-cased keys and 'colour' mapped to 'color'."""
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v
        return newargs

    def clear(self):
        """
        Delete all lines from the current panel and clear its axes.

        Line numbering restarts from 0.  The colour and line-style cycles
        are reset as well.
        """
        self.delete(range(len(self.lines)))
        self.axes.clear()
        self.color = 0
        self.linestyle = 0
        # Empty the list in place.  Rebinding self.lines would detach it from
        # self.subplots[self.i]['lines'], which show() reads for the legend.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Set the colour map, the line styles and the position in each cycle.

        color, linestyle: 0-relative index of the next colour / line style.
            An out-of-range value leaves the current index unchanged.
        colormap: list/tuple of colour names, or a whitespace-separated
            string of them.
        linestyles: list/tuple of style names (keys of self.lsalias), or a
            whitespace-separated string.  Unknown names map to a solid line.
            Any other type leaves the current styles unchanged.
        """
        if colormap:
            if isinstance(colormap, (list, tuple)):
                self.colormap = list(colormap)
            elif isinstance(colormap, str):
                self.colormap = colormap.split()
        if 0 <= color < len(self.colormap):
            self.color = color
        if linestyles:
            if isinstance(linestyles, (list, tuple)):
                styles = linestyles
            elif isinstance(linestyles, str):
                styles = linestyles.split()
            else:
                styles = None
            if styles is not None:
                self.linestyles = [self._dashes(ls) for ls in styles]
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.

        numbers may be a single index or an iterable of indices.  Indices
        out of range or already deleted are ignored.  Deleted lines are
        hidden (linestyle 'None') and their slot is set to None.  The
        remaining lines are NOT renumbered.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes.

        Returns the live dict applied to subsequently plotted lines (not a
        copy); see set_line().
        """
        return self.attributes

    def hist(self, x=None, y=None, msk=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        The bins are drawn as a stepped line via plot().
        x, y: bin starts and heights.  If only one of them is given it is
            taken as the heights, and x defaults to 0..len-1 as in plot().
        msk: optional per-bin mask (same length as x), expanded to match
            the stepped points and passed to plot().
        fmt, add: as in plot().

        Nothing is plotted if the inputs are empty or their lengths differ.
        """
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return
        if msk is not None and len(msk) != len(x):
            return

        # Each bin contributes two points: x = [x0, x0, x1, x1, ...] and
        # y = [0, y0, y0, y1, y1, ..., y(n-1)], giving vertical rises at the
        # bin starts.
        l2 = 2*len(x)
        x2 = [x[i // 2] for i in range(l2)]
        y2 = [0.0] + [y[(i - 1) // 2] for i in range(1, l2)]
        m2 = None
        if msk is not None:
            m2 = [msk[i // 2] for i in range(l2)]

        self.plot(x2, y2, m2, fmt, add)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.

        While buffering is on, show() does nothing.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        loc is an integer position code:
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center
        0 is also accepted and passed through to matplotlib.  Any other
        integer disables the legend.  A non-integer leaves the current
        setting unchanged.
        """
        if isinstance(loc, int):
            # The previous check, 0 > loc > 10, could never be true.
            if not 0 <= loc <= 10:
                loc = None
            self.loc = loc
        self.show()

    def _masked_segments(self, x, y, mask, fmt=None):
        """Build plot() arguments for the runs of x/y selected by mask."""
        # A run starts at an element equal to 1 and extends up to (not
        # including) the next element equal to 0, or to the end.
        mask = list(mask)
        n = len(mask)
        segments = []
        i = 0
        while i < n:
            if mask[i] != 1:
                i += 1
                continue
            j = i + 1
            while j < n and mask[j] != 0:
                j += 1
            segments.append(x[i:j])
            segments.append(y[i:j])
            if fmt is not None:
                segments.append(fmt)
            i = j
        return segments

    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes.

        The argument list works a bit like the matlab plot() function.  If
        only one of x and y is given it is taken as the ordinate, and x
        defaults to 0..len-1.
        mask: optional sequence.  Only the runs starting at a 1 and ending
            before the next 0 are drawn, each as its own segment.
        fmt: matplotlib format string, applied to every segment.
        add: None appends a new line.  N (1-based) adds the segments to
            existing line N-1.  0 adds them to the last line.  A value that
            does not name an existing, undeleted line appends a new line.
        """
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if mask is None:
            if fmt is None:
                line = self.axes.plot(x, y)
            else:
                line = self.axes.plot(x, y, fmt)
        else:
            line = self.axes.plot(*self._masked_segments(x, y, mask, fmt))

        # Add to an existing line?  (The old test, len(lines) < add < 0,
        # could never be true.)
        if add == 0:
            add = len(self.lines)
        if (add is None or not 0 < add <= len(self.lines)
                or self.lines[add - 1] is None):
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            i = add - 1
            self.lines[i].extend(line)

        # Set/reset attributes for the line.
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color': gotcolour = True
            for segment in self.lines[i]:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and len(self.colormap):
            # Guard against a colour map shortened since the index was set.
            self.color %= len(self.colormap)
            # With a single colour, lines are told apart by dash style.
            single = len(self.colormap) == 1
            if single and self.linestyles:
                self.linestyle %= len(self.linestyles)
            for segment in self.lines[i]:
                segment.set_color(self.colormap[self.color])
                if single and self.linestyles:
                    segment.set_dashes(self.linestyles[self.linestyle])
            self.color = (self.color + 1) % len(self.colormap)

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        Registers a one-shot button-press handler that prints the world
        coordinates of the click and then deregisters itself.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The selected region is printed, as (x0, y0) (x1, y1) in world
        coordinates, when the button is released.  The return value is
        currently always the placeholder [0.0, 0.0, 0.0, 0.0]; see the
        comment at the end of this method.  Rubber-band drawing uses the
        canvas's _tkcanvas, so this assumes a Tk canvas.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.  Types not present in
        self.events are ignored.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events: return

        if func is None:
            if self.events[type] is not None:
                # It's not clear that this does anything.
                self.canvas.mpl_disconnect(self.events[type])
                self.events[type] = None

                # It seems to be necessary to return events to the toolbar.
                if type == 'motion_notify':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.mouse_move)
                elif type == 'button_press':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.press)
                elif type == 'button_release':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.release)

        else:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)

    def release(self):
        """
        Release buffered graphics.

        Turns buffering off and calls show().
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  Environment variables in it
        are expanded.  The image format is determined from the file suffix;
        'png', 'ps', and 'eps' are recognized.  If no file name is specified,
        'asapyyyymmdd_hhmmss.png' is created in the current directory.

        orientation ('landscape' or 'portrait') is used for '.ps' output
        only.  If omitted it is chosen from the figure's aspect ratio.  dpi
        is used for the other formats and defaults to 150.

        Problems are reported by printing a message.  IOError is caught and
        an unrecognised suffix is rejected; neither raises.
        """
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        d = ['png', '.ps', 'eps']

        from os.path import expandvars
        fname = expandvars(fname)

        if fname[-3:].lower() not in d:
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if fname[-3:].lower() == ".ps":
                from matplotlib import __version__ as mv
                w = self.figure.figwidth.get()
                h = self.figure.figheight.get()

                if orientation is None:
                    # auto oriented
                    if w > h:
                        orientation = 'landscape'
                    else:
                        orientation = 'portrait'
                # Hack to circumvent a PostScript bug in matplotlib < 0.86:
                # scale the figure to fit an A4-sized page.  Compare
                # (major, minor) so that e.g. 1.0 is not mistaken for an
                # early release.
                vm = match(r'(\d+)\.(\d+)', mv)
                if vm is not None and \
                        (int(vm.group(1)), int(vm.group(2))) < (0, 86):
                    a4w = 8.25
                    a4h = 11.25
                    if orientation == 'landscape':
                        ds = min(a4h/w, a4w/h)
                    else:
                        ds = min(a4w/w, a4h/h)
                    self.figure.set_figsize_inches((ds * w, ds * h))
                self.canvas.print_figure(fname, orientation=orientation)
                print('Written file %s' % (fname))
            else:
                if dpi is None:
                    dpi = 150
                self.canvas.print_figure(fname, dpi=dpi)
                print('Written file %s' % (fname))
        except IOError:
            # sys.exc_info() works on every Python version (the handler
            # previously referred to an undefined name 'err').
            msg = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, msg))
            return

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the axes by calling the relevant Axes.set_*()
        method.  'colour' is accepted in place of 'color', both in what and
        in keyword names.

        Afterwards the title font (and, for multi-panel layouts, the axis
        label fonts) is shrunk according to the panel layout.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = self._normalize_kwargs(kwargs)

        getattr(self.axes, "set_%s" % what)(*args, **newargs)
        # NOTE(review): sizes are reduced relative to their *current* value,
        # so repeated calls keep shrinking the fonts -- confirm intended.
        # Floor division keeps the original Python 2 integer arithmetic.
        s = self.axes.title.get_size()
        tsize = s - (self.cols + self.rows) // 2 - 1
        self.axes.title.set_size(tsize)
        if self.cols > 1:
            xfsize = self.axes.xaxis.label.get_size() - (self.cols + 1) // 2
            self.axes.xaxis.label.set_size(xfsize)
        if self.rows > 1:
            yfsize = self.axes.yaxis.label.get_size() - (self.rows + 1) // 2
            self.axes.yaxis.label.set_size(yfsize)

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  'colour' is accepted in place of 'color', both in what and
        in keyword names.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = self._normalize_kwargs(kwargs)

        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    def _merge_limits(self, old, new):
        """Overlay the non-None entries of new onto the limit pair old."""
        merged = list(old)
        if new is not None:
            for i, value in enumerate(new):
                if i < len(merged) and value is not None:
                    merged[i] = value
        return merged

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().
        A None entry (or a None argument) keeps the existing limit.  The
        current panel is left unchanged.
        """
        # Use a local axes reference; rebinding self.axes/self.lines here
        # used to switch the current panel to the last one.
        for sp in self.subplots:
            ax = sp['axes']
            ax.set_xlim(self._merge_limits(ax.get_xlim(), xlim))
            ax.set_ylim(self._merge_limits(ax.get_ylim(), ylim))

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.  If the line was deleted,
        nothing happens.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is harmless).
        """

        redraw = False
        existing = number is not None and 0 <= number < len(self.lines)
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'

            if existing:
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw: self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols > 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols <= 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n produces a rectangular array of empty panels
        (clearing the figure first when n < 0).  The number of these may be
        limited by nplots.

        ganged squeezes the panels together and hides interior tick labels.
        """
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if rows < 1: rows = 1

        if cols <= 0:
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({'axes': self.figure.add_subplot(rows,
                                                                  cols, n+1),
                                  'lines': []})

            if i == 0: self.subplot(0)

            self.rows = 0
            self.cols = 0

        else:
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols

            for i in range(nplots):
                axes = self.figure.add_subplot(rows, cols, i+1)
                self.subplots.append({'axes': axes, 'lines': []})

                if ganged:
                    if rows > 1 or cols > 1:
                        # Squeeze the plots together.
                        pos = axes.get_position()
                        if cols > 1: pos[2] *= 1.2
                        if rows > 1: pos[3] *= 1.2
                        axes.set_position(pos)

                    # Suppress tick labelling for interior subplots.
                    if i <= (rows-1)*cols - 1:
                        if i+cols < nplots:
                            # Suppress x-labels for frames with a frame
                            # directly below them.
                            axes.xaxis.set_major_locator(NullLocator())
                            axes.xaxis.label.set_visible(False)
                    if i % cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in axes.yaxis.majorTicks:
                            tick.label1On = False
                        axes.yaxis.label.set_visible(False)
                    if (i+1) % cols:
                        # Not in the rightmost column: blank one x tick
                        # label (see MyFormatter).
                        axes.xaxis.set_major_formatter(MyFormatter())

            self.rows = rows
            self.cols = cols

            self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The existing title text is updated in place if it is still part of
        the figure.  Otherwise (first call, or after figure.clear()) a new
        one is added.
        """
        if title is not None:
            self.title = title

        # Reuse the text object so repeated calls do not stack titles.
        old = getattr(self, '_title_text', None)
        if old is not None and old in self.figure.texts:
            old.set_text(self.title)
        else:
            self._title_text = self.figure.text(0.5, 0.95, self.title,
                                                horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering is on or the legend is disabled (self.loc is None),
        this redraws the legend of every panel.  Lines without a label are
        labelled by their 1-based position.
        """
        if self.buffering or self.loc is None:
            return
        for sp in self.subplots:
            handles = []
            labels = []
            for n, line in enumerate(sp['lines']):
                # Skip deleted (None) and empty (fully masked) lines.
                if not line:
                    continue
                handles.append(line[0])
                lbl = line[0].get_label()
                if lbl == '':
                    lbl = str(n + 1)
                labels.append(lbl)

            if len(handles):
                sp['axes'].legend(tuple(handles), tuple(labels), self.loc)
            else:
                sp['axes'].legend((' '))

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        i selects a panel, inc moves relative to the current one.  The
        result wraps modulo the number of panels.  Does nothing if there
        are no panels.
        """
        l = len(self.subplots)
        if l:
            if i is not None:
                self.i = i

            if inc is not None:
                self.i += inc

            self.i %= l
            self.axes  = self.subplots[self.i]['axes']
            self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure.

        Arguments are passed unchanged to Figure.text().
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.