source: trunk/python/asaplotbase.py @ 1032

Last change on this file since 1032 was 1032, checked in by mar637, 18 years ago

Fix for Ticket #32; re-introduced mpl version dependency

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.6 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15from matplotlib.ticker import ScalarFormatter
16from matplotlib.ticker import NullLocator
17
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter variant that leaves the tick at position 0 unlabelled.

    Every other position is formatted exactly as ScalarFormatter would.
    set_panels() installs it on the bottom row of ganged layouts,
    presumably so the first label of a panel does not collide with the
    last label of its left-hand neighbour.
    """
    def __call__(self, x, pos=None):
        # pos is None when matplotlib formats a value outside a tick
        # context; that case is still delegated to ScalarFormatter.
        if pos != 0:
            return ScalarFormatter.__call__(self, x, pos)
        return ''
24
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    The figure is divided into panels.  Each panel is stored in
    self.subplots as a dict {'axes': Axes, 'lines': [...]}, where each entry
    of 'lines' is the list of Line2D objects produced by one plot() call
    (or None once that line has been deleted).  self.axes and self.lines
    refer to the panel currently selected with subplot().
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        The default colour and line-style cycles may be overridden with the
        asap rcParams 'plotter.colours' and 'plotter.linestyles'
        (whitespace-separated strings).
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Set matplotlib default colour sequence.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        colours = asaprcParams['plotter.colours']
        if isinstance(colours, str) and len(colours) > 0:
            self.colormap = colours.split()

        # Named dash patterns (alternating on/off lengths) for set_dashes().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        lstyles = asaprcParams['plotter.linestyles']
        if isinstance(lstyles, str) and len(lstyles) > 0:
            styles = lstyles.split()
        # Unknown names fall back to a solid line (same rule as palette()).
        self.linestyles = self._resolve_linestyles(styles)

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def _resolve_linestyles(self, styles):
        """Map style names to self.lsalias dash patterns ('line' if unknown)."""
        solid = self.lsalias['line']
        return [self.lsalias.get(ls, solid) for ls in styles]

    def clear(self):
        """
        Delete all lines from the current panel and clear its axes.

        Line numbering will restart from 1, and the colour and line-style
        cycles restart from their first entries.
        """
        if self.lines:
            # One delete() call, hence one show(), for all lines.
            self.delete(list(range(len(self.lines))))
        self.axes.clear()
        self.color = 0
        self.linestyle = 0
        # Empty in place: self.lines is the same list object held in the
        # current subplot dict, which show() and subplot() read.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Select the next colour/line style and optionally replace the cycles.

        color       index into the colour cycle for the next plot();
                    ignored if out of range.
        colormap    new colour cycle: a list/tuple or whitespace-separated str.
        linestyle   index into the line-style cycle; ignored if out of range.
        linestyles  new line-style cycle: list/tuple or whitespace-separated
                    str of names from self.lsalias; unknown names map to
                    'line'.  Other types leave the cycle unchanged.

        If a replacement cycle is shorter than the current index, the index
        is reset to 0 so that plot() never indexes past its end.
        """
        if colormap:
            if isinstance(colormap, list):
                self.colormap = colormap
            elif isinstance(colormap, tuple):
                self.colormap = list(colormap)
            elif isinstance(colormap, str):
                self.colormap = colormap.split()
        if 0 <= color < len(self.colormap):
            self.color = color
        elif self.color >= len(self.colormap):
            self.color = 0

        if linestyles:
            styles = None
            if isinstance(linestyles, (list, tuple)):
                styles = linestyles
            elif isinstance(linestyles, str):
                styles = linestyles.split()
            if styles is not None:
                self.linestyles = self._resolve_linestyles(styles)
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle
        elif self.linestyle >= len(self.linestyles):
            self.linestyle = 0

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number(s), default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers may be a single int or an iterable of ints; out-of-range and
        already-deleted numbers are ignored.  Deleted lines are hidden
        (linestyle 'None') and their slot in self.lines is set to None.
        """
        if numbers is None:
            numbers = [len(self.lines)-1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes.
        """
        return self.attributes

    @staticmethod
    def _histogram_steps(x, y, msk=None):
        """
        Expand bin-start x, heights y (and optional mask) into a step outline.

        Returns (x2, y2, m2), each of length 2*len(x): every x (and mask)
        value is duplicated, y2 starts at 0.0 and then follows
        y0, y0, y1, y1, ..., y[-1].  m2 is None when msk is None.
        """
        n2 = 2 * len(x)
        x2 = [x[i // 2] for i in range(n2)]
        y2 = [0.0] + [y[(i - 1) // 2] for i in range(1, n2)]
        m2 = None
        if msk is not None:
            m2 = [msk[i // 2] for i in range(n2)]
        return x2, y2, m2

    def hist(self, x=None, y=None, msk=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        x, y, msk, fmt and add are interpreted as in plot(); msk (optional)
        masks whole bins.  Nothing is plotted if x and y differ in length
        or are empty.
        """
        if x is None:
            if y is None:
                return
            x = list(range(len(y)))
        elif y is None:
            # Mirror plot(): a single sequence is taken as the y values.
            y = x
            x = list(range(len(y)))

        if len(x) != len(y) or len(x) == 0:
            return

        x2, y2, m2 = self._histogram_steps(x, y, msk)
        self.plot(x2, y2, m2, fmt, add)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Set the legend location and redraw.

        An int from 0 to 10 selects the matplotlib location code (0 is the
        initial default, matplotlib's automatic 'best' placement in current
        versions); any other int disables the legend.  Non-int values leave
        the current setting unchanged.
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        """
        if isinstance(loc, int):
            if 0 <= loc <= 10:
                self.loc = loc
            else:
                # show() skips legends while self.loc is None.
                self.loc = None
        self.show()

    @staticmethod
    def _scale_font(text, shrink):
        """
        Set a Text artist's size to its first-seen size minus shrink.

        The unscaled size is cached on the artist so that repeated plot()
        calls on the same axes do not shrink the font cumulatively.
        """
        base = getattr(text, '_asap_basesize', None)
        if base is None:
            base = text.get_size()
            text._asap_basesize = base
        text.set_size(base - shrink)

    @staticmethod
    def _mask_segments(mask):
        """
        Return (start, end) index pairs (end exclusive) of the runs where
        mask is true.
        """
        runs = []
        start = None
        for idx in range(len(mask)):
            if mask[idx]:
                if start is None:
                    start = idx
            elif start is not None:
                runs.append((start, idx))
                start = None
        if start is not None:
            runs.append((start, len(mask)))
        return runs

    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes.  The ASAPlot graphics window will be mapped and raised.

        The argument list works a bit like the matlab plot() function.

        x, y   data; if only one is given it is used as y against its index.
        mask   optional sequence; only runs where it is true are drawn, each
               as a separate segment.
        fmt    optional matplotlib format string, applied to every segment.
        add    1-relative number of an existing line to extend (0 means the
               last line).  None, an out-of-range number or a deleted line
               starts a new line.
        """
        if x is None:
            if y is None: return
            x = list(range(len(y)))

        elif y is None:
            y = x
            x = list(range(len(y)))

        # Shrink fonts according to the panel layout (once per artist).
        ax = self.axes
        self._scale_font(ax.title, (self.cols + self.rows) // 2)
        if self.cols > 1:
            self._scale_font(ax.xaxis.label, self.cols // 2)
        if self.rows > 1:
            self._scale_font(ax.yaxis.label, self.rows // 2)

        if mask is None:
            if fmt is None:
                line = self.axes.plot(x, y)
            else:
                line = self.axes.plot(x, y, fmt)
        else:
            args = []
            for start, end in self._mask_segments(mask):
                args.append(x[start:end])
                args.append(y[start:end])
                if fmt is not None:
                    args.append(fmt)
            line = self.axes.plot(*args)

        # Add to an existing line?
        if add == 0:
            add = len(self.lines)
        if add is None or not 1 <= add <= len(self.lines) \
               or self.lines[add - 1] is None:
            # Don't add.
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            i = add - 1
            self.lines[i].extend(line)

        # Set/reset attributes for the line.
        segments = self.lines[i]
        gotcolour = 'color' in self.attributes
        for k, v in self.attributes.items():
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and self.colormap:
            if self.color >= len(self.colormap):
                self.color = 0
            single = len(self.colormap) == 1
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                # With a single colour, lines are told apart by dash style.
                if single and self.linestyles:
                    segment.set_dashes(self.linestyles[self.linestyle])
            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        The world coordinates of the next button press are printed.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The region is drawn on the Tk canvas while dragging and its world
        coordinates are printed on release.  The return value is currently
        the placeholder [0.0, 0.0, 0.0, 0.0]; the intended result is
        [x0, y0, x1, y1] in world coordinates.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.

        Note that in TkAgg keyboard button presses don't generate an event.
        self.events and self.figmgr are not set by this class; they are
        expected to be provided by a subclass.
        """

        if type not in self.events: return

        if func is None:
            if self.events[type] is not None:
                # It's not clear that this does anything.
                self.canvas.mpl_disconnect(self.events[type])
                self.events[type] = None

                # It seems to be necessary to return events to the toolbar.
                if type == 'motion_notify':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.mouse_move)
                elif type == 'button_press':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.press)
                elif type == 'button_release':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.release)

        else:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    @staticmethod
    def _mpl_version(version):
        """
        Return (major, minor) ints parsed from a matplotlib version string,
        tolerating suffixes such as '0.91.2svn' or '1.0rc1'.
        """
        nums = []
        for part in version.split('.')[:2]:
            found = match(r'\d+', part)
            if found:
                nums.append(int(found.group(0)))
            else:
                nums.append(0)
        while len(nums) < 2:
            nums.append(0)
        return tuple(nums)

    def _save_ps(self, fname, orientation=None):
        """
        Write PostScript scaled to fit an A4 page.

        The figure is temporarily resized and its original size restored
        even if savefig() fails.
        """
        from matplotlib import __version__ as mv
        w = self.figure.figwidth.get()
        h = self.figure.figheight.get()

        if orientation is None:
            # auto oriented
            if w > h:
                orientation = 'landscape'
            else:
                orientation = 'portrait'
        a4w = 8.25
        a4h = 11.25
        if orientation == 'landscape':
            ds = min(a4h/w, a4w/h)
        else:
            ds = min(a4w/w, a4h/h)
        self.figure.set_figsize_inches((ds * w, ds * h))
        try:
            # hack to circumvent ps bug in early versions of mpl
            if self._mpl_version(mv) < (0, 87):
                self.figure.savefig(fname, orientation=orientation)
            else:
                self.figure.savefig(fname, orientation=orientation,
                                    papertype="a4")
        finally:
            self.figure.set_figsize_inches((w, h))

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  The image format is determined
        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
        file name is specified 'asapyyyymmdd_hhmmss.png' is created in the
        current directory.  Environment variables in fname are expanded.

        orientation ('landscape'/'portrait', default: from the figure aspect)
        applies to '.ps' output, which is scaled to fit an A4 page.  dpi
        (default 150) applies to the other formats.  IOErrors are reported
        and not raised.
        """
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        from os.path import expandvars
        fname = expandvars(fname)

        ext = fname[-3:].lower()
        if ext not in ('png', '.ps', 'eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if ext == ".ps":
                self._save_ps(fname, orientation)
            else:
                if dpi is None:
                    dpi = 150
                self.figure.savefig(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            msg = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, msg))

    @staticmethod
    def _normalise_kwargs(kwargs):
        """Lower-case keyword names and translate 'colour' to 'color'."""
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour':
                k = 'color'
            newargs[k] = v
        return newargs

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the axes by calling the relevant Axes.set_*()
        method.  A trailing 'colour' in what, and 'colour' keyword arguments,
        are translated to 'color'; keyword names are lower-cased.
        """

        if what is None: return
        if what.endswith('colour'): what = what[:-6] + 'color'

        newargs = self._normalise_kwargs(kwargs)
        getattr(self.axes, "set_%s" % what)(*args, **newargs)

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  A trailing 'colour' in what, and 'colour' keyword arguments,
        are translated to 'color'; keyword names are lower-cased.
        """

        if what is None: return
        if what.endswith('colour'): what = what[:-6] + 'color'

        newargs = self._normalise_kwargs(kwargs)
        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    @staticmethod
    def _merge_limits(current, requested):
        """
        Return list(current) with entries replaced by the non-None values of
        requested (extra requested entries are ignored).
        """
        merged = list(current)
        if requested is not None:
            for idx in range(min(len(requested), len(merged))):
                if requested[idx] is not None:
                    merged[idx] = requested[idx]
        return merged

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().
        A None entry keeps that panel's current limit.  The current panel
        (self.axes / self.lines) is not changed.
        """
        for sub in self.subplots:
            ax = sub['axes']
            ax.set_xlim(self._merge_limits(ax.get_xlim(), xlim))
            ax.set_ylim(self._merge_limits(ax.get_ylim(), ylim))
        return

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is a no-op).
        """

        existing = number is not None and 0 <= number < len(self.lines)
        redraw = False
        for k, v in self._normalise_kwargs(kwargs).items():
            if existing:
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw: self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True the panels are packed without spacing and interior
        tick labels and axis labels are hidden.
        """
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if rows < 1: rows = 1

        if cols <= 0:
            # Smallest near-square grid holding 'rows' panels.
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({'axes': self.figure.add_subplot(rows,
                                                                  cols, n+1),
                                  'lines': []})

            if i == 0: self.subplot(0)

            self.rows = 0
            self.cols = 0

        else:
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols
            if ganged:
                hsp, wsp = None, None
                if rows > 1: hsp = 0.0001
                if cols > 1: wsp = 0.0001
                self.figure.subplots_adjust(wspace=wsp, hspace=hsp)
            for i in range(nplots):
                ax = self.figure.add_subplot(rows, cols, i+1)
                self.subplots.append({'axes': ax, 'lines': []})

                if ganged:
                    # Suppress x-labels for frames with a frame below them.
                    if i < (rows-1)*cols and i+cols < nplots:
                        ax.xaxis.set_major_locator(NullLocator())
                        ax.xaxis.label.set_visible(False)
                    if i % cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in ax.yaxis.majorTicks:
                            tick.label1On = False
                        ax.yaxis.label.set_visible(False)
                    # disable the first tick of [1:ncol-1] of the last row
                    if (nplots-cols) < i <= nplots-1:
                        ax.xaxis.set_major_formatter(MyFormatter())
            self.rows = rows
            self.cols = cols
            self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The existing title text is updated in place; a new one is created
        only if the figure has none (e.g. after figure.clear()).
        """
        if title is not None:
            self.title = title

        old = getattr(self, '_titletext', None)
        if old is not None and old in self.figure.texts:
            old.set_text(self.title)
        else:
            self._titletext = self.figure.text(0.5, 0.95, self.title,
                                               horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering or the legend is disabled (self.loc is None), every
        panel's legend is rebuilt from the first segment of each live line,
        labelled with its Line2D label or else its 1-relative line number.
        """
        if self.buffering or self.loc is None:
            return
        for sub in self.subplots:
            handles = []
            labels = []
            for num, line in enumerate(sub['lines']):
                # Skip deleted lines (None) and fully-masked ones ([]).
                if not line:
                    continue
                handles.append(line[0])
                lbl = line[0].get_label()
                if lbl == '':
                    lbl = str(num + 1)
                labels.append(lbl)

            if handles:
                sub['axes'].legend(tuple(handles), tuple(labels), self.loc)
            else:
                # (' ') is the plain string ' ', used as a single blank
                # label to blank out the legend.
                sub['axes'].legend((' '))

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().  inc steps relative to the current
        panel; the index wraps around modulo the number of panels.
        """
        l = len(self.subplots)
        if l:
            if i is not None:
                self.i = i

            if inc is not None:
                self.i += inc

            self.i %= l
            self.axes  = self.subplots[self.i]['axes']
            self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure (arguments as for Figure.text()).
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.