source: trunk/python/asaplotbase.py @ 1147

Last change on this file since 1147 was 1147, checked in by mar637, 18 years ago

added linecatalog plotting; some font scaling fixes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 25.4 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties as FP
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15from matplotlib.ticker import ScalarFormatter
16from matplotlib.ticker import NullLocator
17from matplotlib.transforms import blend_xy_sep_transform
18
def _version_tuple(version):
    """
    Return (major, minor) as integers parsed from the leading digits of a
    version string such as '0.87.7' or '1.0rc1'; (0, 0) if it cannot be
    parsed.
    """
    found = match(r'(\d+)\.(\d+)', version)
    if found is None:
        return (0, 0)
    return (int(found.group(1)), int(found.group(2)))

# Compare (major, minor) so that 1.x and later do not trigger the warning
# and pre-release suffixes such as '1.0rc1' do not break the import.
if _version_tuple(matplotlib.__version__) < (0, 87):
    print("Warning: matplotlib version < 0.87. This might cause errors. Please upgrade.")
21
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter variant that leaves the tick at position 0 unlabelled.

    set_panels() installs it on the last-row panels of a ganged layout,
    presumably so labels of neighbouring panels do not overlap where the
    panels meet.
    """

    def __call__(self, x, pos=None):
        # Any position other than 0 (including None) is formatted normally.
        if pos != 0:
            return ScalarFormatter.__call__(self, x, pos)
        return ''
28
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    The figure is divided into panels by set_panels().  Each panel is kept
    in self.subplots as a dict with keys 'axes' (the matplotlib Axes) and
    'lines' (a list whose entries are the lists of line segments returned by
    Axes.plot(), or None once deleted).  self.axes and self.lines refer to
    the current panel, selected with subplot().

    NOTE(review): register() uses self.events and self.figmgr, which are
    not initialised here -- presumably provided by subclasses; confirm.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        Non-empty string values of the ASAP rc parameters 'plotter.colours'
        and 'plotter.linestyles' override the default colour and line style
        cycles.
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None
        # Panel grid used by show() for font scaling.  set_panels() overwrites
        # these; initialising them here lets show() work when rows < 1.
        self.rows = 0
        self.cols = 0

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Set matplotlib default colour sequence.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()

        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named dash sequences, applied with Line2D.set_dashes() in plot().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        self.linestyles = self._dash_patterns(styles)

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def _dash_patterns(self, styles):
        """
        Map style names (keys of self.lsalias) to dash sequences; unknown
        names fall back to the solid 'line' sequence.
        """
        solid = self.lsalias['line']
        return [self.lsalias.get(name, solid) for name in styles]

    def clear(self):
        """
        Delete all lines from the plot.  Line numbering will restart from 0.
        """
        if self.lines:
            # One delete() call so show() runs once, not once per line.
            self.delete(range(len(self.lines)))
        self.axes.clear()
        self.color = 0
        # Empty the list in place: it is the same object as the current
        # panel's 'lines' entry, which show() and subplot() read.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Set the colour and line style cycles and the position within them.

        color      -- index into the colour cycle of the next colour to use;
                      ignored if out of range.
        colormap   -- new colour cycle, either a whitespace-separated string
                      or a list/tuple of colour names.
        linestyle  -- index of the next line style to use; ignored if out
                      of range.
        linestyles -- new line style cycle, a whitespace-separated string or
                      a list/tuple of names from self.lsalias; unknown names
                      map to a solid line.

        Line styles only cycle when the colour cycle has a single colour
        (see plot()).
        """
        if colormap:
            if isinstance(colormap, str):
                self.colormap = colormap.split()
            elif isinstance(colormap, (list, tuple)):
                self.colormap = list(colormap)
        if 0 <= color < len(self.colormap):
            self.color = color
        elif self.color >= len(self.colormap):
            # Keep the current index valid for a shorter colour cycle.
            self.color = 0

        if linestyles:
            styles = None
            if isinstance(linestyles, str):
                styles = linestyles.split()
            elif isinstance(linestyles, (list, tuple)):
                styles = linestyles
            if styles is not None:
                self.linestyles = self._dash_patterns(styles)
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle
        elif self.linestyle >= len(self.linestyles):
            self.linestyle = 0

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers may be a single line number or an iterable of them; numbers
        out of range or already deleted are ignored.  Deleted segments are
        hidden by setting their linestyle to 'None'.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes.
        """
        return self.attributes

    @staticmethod
    def _histogram_steps(x, ydat, ymsk=None):
        """
        Return (x2, y2, m2) vertex lists outlining a step histogram.

        Each bin k contributes a riser at x[k] and a plateau at ydat[k]
        starting there; the outline starts from 0.0.  Every vertex carries
        the mask of the y value it uses (the first vertex uses bin 0's), so
        masking one bin does not hide its neighbour's plateau.  ymsk None
        means nothing is masked.
        """
        x2 = []
        y2 = []
        m2 = []
        for i in range(2 * len(x)):
            x2.append(x[i // 2])
            if i == 0:
                k = 0
                y2.append(0.0)
            else:
                k = (i - 1) // 2
                y2.append(ydat[k])
            if ymsk is None:
                m2.append(False)
            else:
                m2.append(ymsk[k])
        return x2, y2, m2

    def hist(self, x=None, y=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        fmt and add are as in plot(); as in plot(), if only x is given it is
        used as the y data.  If y provides raw_mask()/raw_data() (a masked
        array) its mask is carried over to the step outline; a None mask is
        treated as "nothing masked".  Nothing is plotted if x and y differ
        in length or are empty.
        """
        from matplotlib.numerix.ma import MaskedArray
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(y) == 0:
            return

        if hasattr(y, 'raw_mask'):
            ymsk = y.raw_mask()
            ydat = y.raw_data()
        else:
            ymsk = None
            ydat = y

        x2, y2, m2 = self._histogram_steps(x, ydat, ymsk)
        self.plot(x2, MaskedArray(y2, mask=m2, copy=0), fmt, add)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        Any other value for loc disables the legend:
             0: best (automatic placement)
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        The legend is drawn by the next show().
        """
        self.loc = None
        if isinstance(loc, int) and 0 <= loc <= 10:
            self.loc = loc

    def plot(self, x=None, y=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes.  The ASAPlot graphics window will be mapped and raised.

        The argument list works a bit like the matlab plot() function.

        add selects an existing line to extend instead of starting a new
        one: it is the 1-relative line number, and 0 means the most recently
        plotted line.  None, a value outside 0..len(lines), or a deleted
        line starts a new line.

        Unless a 'color' default was set with set_line(), the line takes the
        next colour of the colour cycle; with a single-colour cycle the line
        style cycles instead.
        """
        if x is None:
            if y is None: return
            x = range(len(y))

        elif y is None:
            y = x
            x = range(len(y))
        if fmt is None:
            line = self.axes.plot(x, y)
        else:
            line = self.axes.plot(x, y, fmt)

        # Add to an existing line?
        target = None
        if add is not None and 0 <= add <= len(self.lines):
            if add == 0:
                target = len(self.lines) - 1
            else:
                target = add - 1
            if target < 0 or self.lines[target] is None:
                target = None
        if target is None:
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            self.lines[target].extend(line)
            i = target

        # Set/reset attributes for the line.
        segments = self.lines[i]
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color':
                gotcolour = True
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and len(self.colormap):
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                if len(self.colormap) == 1:
                    segment.set_dashes(self.linestyles[self.linestyle])

            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if len(self.colormap) == 1:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        The world coordinates of the next button press are printed, after
        which the handler deregisters itself.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The return value is [x0, y0, x1, y1] in world coordinates.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.  Types that are not keys
        of self.events are ignored.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events:
            return

        if func is not None:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)
            return

        if self.events[type] is None:
            return

        # It's not clear that this does anything.
        self.canvas.mpl_disconnect(self.events[type])
        self.events[type] = None

        # It seems to be necessary to return events to the toolbar.
        toolbar_handlers = {'motion_notify': 'mouse_move',
                            'button_press': 'press',
                            'button_release': 'release'}
        if type in toolbar_handlers:
            self.canvas.mpl_connect(type + '_event',
                getattr(self.figmgr.toolbar, toolbar_handlers[type]))

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None, papertype=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  The image format is determined
        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
        file name is specified 'asapyyyymmdd_hhmmss.png' is created in the
        current directory.  Environment variables in fname are expanded.

        orientation ('portrait' or 'landscape') and papertype apply to 'ps'
        output; by default the orientation follows the figure's aspect ratio
        and papertype comes from the ASAP rc parameter 'plotter.papertype'.
        dpi applies to the other formats (default 150).

        An IOError while writing is reported on stdout, not raised.
        """
        from os.path import expandvars, splitext
        if papertype is None:
            papertype = asaprcParams['plotter.papertype']
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        fname = expandvars(fname)

        ext = splitext(fname)[1].lower()
        if ext not in ('.png', '.ps', '.eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if ext == '.ps':
                self._save_ps(fname, orientation, papertype)
            else:
                if dpi is None:
                    dpi = 150
                self.figure.savefig(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            # sys.exc_info() works on every Python version this file targets.
            err = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, err))
            return

    def _save_ps(self, fname, orientation, papertype):
        """
        Write PostScript with the figure scaled to fit papertype, then
        restore the original figure size even if writing fails.
        """
        from matplotlib.backends.backend_ps import papersize
        w = self.figure.figwidth.get()
        h = self.figure.figheight.get()

        if orientation is None:
            # oriented
            if w > h:
                orientation = 'landscape'
            else:
                orientation = 'portrait'
        pw, ph = papersize[papertype.lower()]
        if orientation == 'landscape':
            ds = min(ph/w, pw/h)
        else:
            ds = min(pw/w, ph/h)
        self.figure.set_figsize_inches((ds * w, ds * h))
        try:
            self.figure.savefig(fname, orientation=orientation,
                                papertype=papertype.lower())
        finally:
            self.figure.set_figsize_inches((w, h))

    @staticmethod
    def _colour_args(what, kwargs):
        """
        Return (what, kwargs) with the British spelling 'colour' translated
        to matplotlib's 'color' and keyword names lower-cased.
        """
        if what[-6:] == 'colour':
            what = what[:-6] + 'color'
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour':
                k = 'color'
            newargs[k] = v
        return what, newargs

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the axes by calling the relevant Axes.set_*()
        method.  The British spelling 'colour' is accepted for 'color', both
        in what and in keyword names.
        """

        if what is None: return
        what, newargs = self._colour_args(what, kwargs)
        getattr(self.axes, "set_%s" % what)(*args, **newargs)
        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  The British spelling 'colour' is accepted for 'color', both
        in what and in keyword names.
        """

        if what is None: return
        what, newargs = self._colour_args(what, kwargs)
        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        A None entry keeps that panel's current limit.  The current panel
        (self.axes / self.lines) is left unchanged.
        """
        for s in self.subplots:
            ax = s['axes']
            newxlim = list(ax.get_xlim())
            newylim = list(ax.get_ylim())
            if xlim is not None:
                for i, lim in enumerate(xlim):
                    if lim is not None:
                        newxlim[i] = lim
            if ylim is not None:
                for i, lim in enumerate(ylim):
                    if lim is not None:
                        newylim[i] = lim
            ax.set_xlim(newxlim)
            ax.set_ylim(newylim)
        return

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is a no-op).
        """

        redraw = False
        existing = number is not None and 0 <= number < len(self.lines)
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour':
                k = 'color'

            if existing:
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw: self.show()

    @staticmethod
    def _grid_shape(count):
        """
        Return (rows, cols) of the near-square grid set_panels() uses when
        cols == 0: cols = ceil(sqrt(count)), and one row fewer than cols
        when that still holds count panels (e.g. 6 -> (2, 3)).
        """
        side = int(count ** 0.5)
        # Correct any floating-point error in the square root.
        while side * side > count:
            side -= 1
        while side * side < count:
            side += 1
        nrows = side
        if side * (side - 1) >= count:
            nrows = side - 1
        return nrows, side

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True the panels of a regular layout are packed together
        and interior tick labels are suppressed.
        """
        if rows < 1: rows = 1

        if cols <= 0:
            rows, cols = self._grid_shape(rows)

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({'axes': self.figure.add_subplot(rows,
                                                                  cols, n+1),
                                  'lines': []})
            if i == 0: self.subplot(0)

            self.rows = 0
            self.cols = 0
            return

        # Any other n starts afresh: drop existing panels (re-adding the
        # title, which Figure.clear() removes).
        if len(self.subplots):
            self.figure.clear()
            self.set_title()
        self.subplots = []

        if nplots < 1 or rows*cols < nplots:
            nplots = rows*cols
        if ganged:
            hsp, wsp = None, None
            if rows > 1: hsp = 0.0001
            if cols > 1: wsp = 0.0001
            self.figure.subplots_adjust(wspace=wsp, hspace=hsp)
        for i in range(nplots):
            ax = self.figure.add_subplot(rows, cols, i+1)
            self.subplots.append({'axes': ax, 'lines': []})
            if ganged:
                self._gang_labels(ax, i, cols, nplots)
        self.rows = rows
        self.cols = cols
        self.subplot(0)

    def _gang_labels(self, ax, i, cols, nplots):
        """
        Suppress tick labelling for interior panel i of a ganged layout.
        """
        if i + cols < nplots:
            # Suppress x-labels for frames with another frame below them.
            ax.xaxis.set_major_locator(NullLocator())
            ax.xaxis.label.set_visible(False)
        if i % cols:
            # Suppress y-labels for frames not in the left column.
            for tick in ax.yaxis.majorTicks:
                tick.label1On = False
            ax.yaxis.label.set_visible(False)
        # disable the first tick of [1:ncol-1] of the last row
        if (nplots-cols) < i <= nplots-1:
            ax.xaxis.set_major_formatter(MyFormatter())

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The existing title text is updated in place while it is still part
        of the figure, so repeated calls do not stack overlapping titles.
        """
        if title is not None:
            self.title = title
        elif not hasattr(self, 'title'):
            self.title = ''

        previous = getattr(self, '_titletext', None)
        if previous is not None and previous in getattr(self.figure, 'texts', []):
            previous.set_text(self.title)
        else:
            self._titletext = self.figure.text(0.5, 0.95, self.title,
                                               horizontalalignment='center')

    @staticmethod
    def _legend_entries(lines):
        """
        Return (handles, labels) for a panel's 'lines' list: the first
        segment of each non-deleted line and its label, or the 1-relative
        line number when the label is empty.
        """
        handles = []
        labels = []
        for num, line in enumerate(lines):
            if line is None:
                continue
            handles.append(line[0])
            lbl = line[0].get_label()
            if lbl == '':
                lbl = str(num + 1)
            labels.append(lbl)
        return handles, labels

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        When not buffering, (re)builds each panel's legend if self.loc is set
        and scales tick, axis-label and title fonts down with the panel grid
        size.
        """
        if self.buffering:
            return
        if self.loc is not None:
            for sp in self.subplots:
                handles, labels = self._legend_entries(sp['lines'])
                if len(handles):
                    fp = FP(size=rcParams['legend.fontsize'])
                    # Shrink the legend font as entries are added, to >= 6pt.
                    fsz = fp.get_size_in_points() - len(handles)
                    fp.set_size(max(fsz, 6))
                    sp['axes'].legend(tuple(handles), tuple(labels),
                                      self.loc, prop=fp)
                else:
                    sp['axes'].legend((' '))
        self._scale_fonts()

    def _scale_fonts(self):
        """
        Scale tick, axis-label and title font sizes for the panel grid.
        """
        from matplotlib.artist import setp
        xts = FP(size=rcParams['xtick.labelsize']).get_size_in_points() - self.cols // 2
        yts = FP(size=rcParams['ytick.labelsize']).get_size_in_points() - self.rows // 2
        labelsize = FP(size=rcParams['axes.labelsize']).get_size_in_points()
        xoff = 0
        if self.cols > 1: xoff = self.cols
        yoff = 0
        if self.rows > 1: yoff = self.rows
        for sp in self.subplots:
            ax = sp['axes']
            # Scale from the title size first seen for this panel; scaling
            # the current size would shrink the title on every show().
            base = sp.setdefault('titlesize', ax.title.get_size())
            ax.title.set_size(base - (self.cols + self.rows))
            setp(ax.get_xticklabels(), fontsize=xts)
            setp(ax.get_yticklabels(), fontsize=yts)
            ax.xaxis.label.set_size(labelsize - xoff)
            ax.yaxis.label.set_size(labelsize - yoff)

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        inc moves relative to the current panel; the index wraps around.
        Does nothing when no panels exist.
        """
        count = len(self.subplots)
        if not count:
            return
        if i is not None:
            self.i = i
        if inc is not None:
            self.i += inc
        self.i %= count
        self.axes = self.subplots[self.i]['axes']
        self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure.
        """
        self.figure.text(*args, **kwargs)
        self.show()

    def vline_with_label(self, x, y, label,
                         location='bottom', rotate=0.0, **kwargs):
        """
        Plot a vertical line with label in the current panel.
        It takes "world" values for x and y.

        location ('top' or 'bottom', case-insensitive) is where the label
        goes; the line runs between the label and the point (x, y), with y
        clipped at 0.  rotate is the label rotation passed to Axes.text().
        Other keyword arguments go to Axes.axvline(); color defaults to
        'black'.  Autoscaling is suspended while drawing and re-enabled
        afterwards.

        Raises ValueError for any other location.
        """
        where = location.lower()
        if where not in ('top', 'bottom'):
            raise ValueError("location must be 'top' or 'bottom'")
        kwargs.setdefault('color', 'black')
        ax = self.axes
        # need this to suppress autoscaling during this function
        ax.set_autoscale_on(False)
        try:
            if where == 'top':
                y = max(0.0, y)
            else:
                y = min(0.0, y)
            lbloffset = 0.06
            # a rough estimate for the bb of the text
            if rotate > 0.0: lbloffset = 0.03*len(label)
            peakoffset = 0.01
            # get relative (axes) coords of the data point
            xy = ax.transAxes.inverse_xy_tup(ax.transData.xy_tup((x, y)))
            if where == 'top':
                ymax = 1.0 - lbloffset
                ymin = xy[1] + peakoffset
                valign = 'bottom'
                ylbl = ymax + 0.01
            else:
                ymin = lbloffset
                ymax = xy[1] - peakoffset
                valign = 'top'
                ylbl = ymin - 0.01
            # x in data coordinates, y in axes coordinates for the label
            trans = blend_xy_sep_transform(ax.transData, ax.transAxes)
            ax.axvline(x, ymin, ymax, **kwargs)
            ax.text(x, ylbl, label, verticalalignment=valign,
                    horizontalalignment='center',
                    rotation=rotate, transform=trans)
        finally:
            ax.set_autoscale_on(True)
Note: See TracBrowser for help on using the repository browser.