#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains utility functions for creating symbols and managing symbol contexts.
"""


__author__ = "Martin Sandve Alnes"
__date__ = "2007-10-09 -- 2008-12-12"
__copyright__ = "(C) 2007-2008 Martin Sandve Alnes and Simula Research Laboratory"
__license__ = "GNU GPL Version 2, or (at your option) any later version"

import math

import swiginac
import SyFi


class SymbolFactory:
    """Name-keyed cache of symbols; each name maps to exactly one symbol object.

    The names x, y, z, t, infinite and DUMMY are pre-populated from
    SyFi.get_symbol (presumably so they coincide with SyFi's own symbols --
    confirm against SyFi). Any other name gets a new swiginac.symbol the first
    time it is requested, and that same object is returned afterwards.
    """

    def __init__(self):
        predefined = ("x", "y", "z", "t", "infinite", "DUMMY")
        self.symbols = dict((key, SyFi.get_symbol(key)) for key in predefined)

    def __call__(self, name):
        """Return the symbol called 'name', creating and caching it if absent."""
        try:
            return self.symbols[name]
        except KeyError:
            pass
        new_symbol = swiginac.symbol(name)
        self.symbols[name] = new_symbol
        return new_symbol

# Default, module-wide symbol factory. The functional interface in this module
# (symbol, symbol_exists, symbols, symbolic_vector, symbolic_matrix) and the
# default symbol_factory argument of TempSymbolContext all use this single
# instance, so requesting the same name through any of them yields the same
# cached symbol object.
_symbol_factory = SymbolFactory()


class TempSymbolContext:
    """Callable that hands out sequentially numbered temporary symbols.

    The k-th call (k = 0, 1, 2, ...) returns symbol_factory(format % k), so with
    the defaults the symbols are named '_s0', '_s1', ... and are taken from the
    module's default symbol factory.
    """

    def __init__(self, format="_s%d", symbol_factory=_symbol_factory):
        self.symbol_factory = symbol_factory
        self.format = format
        self.i = 0  # index used for the next generated name

    def __call__(self, shape=None):
        """Return the next temporary symbol. 'shape' is accepted but not used."""
        current = self.i
        # Build the name before advancing, so a failing format leaves i unchanged.
        generated_name = self.format % current
        self.i = current + 1
        return self.symbol_factory(generated_name)


# Functional user interface to the default symbol factory:

def symbol_exists(name):  # NOTE(review): possibly unused -- verify before removing.
    """Returns True if a symbol called 'name' is already registered in the
    default sfc symbol factory, False otherwise.

    Unlike symbol(), this only checks the factory's cache and never creates
    a new symbol.
    """
    return name in _symbol_factory.symbols

def symbol(name):
    """Fetch the symbol called 'name' from the module's default symbol factory,
    which creates and caches it the first time the name is requested."""
    factory = _symbol_factory
    return factory(name)

def symbolic_vector(n, name):
    """Returns a length n vector of symbols from the default sfc symbol factory.

    The result is an n x 1 swiginac.matrix whose i-th entry is the symbol
    named 'name<i>' for i = 0, ..., n-1. Every index is zero-padded to the
    number of digits of the largest index (n-1), so padding only appears when
    n > 10 (e.g. n = 12 gives 'name00', ..., 'name11').

    n = 0 is passed straight to swiginac.matrix(0, 1), the same way
    symbolic_matrix handles zero dimensions.

    Raises ValueError if n is negative.
    """
    if n < 0:
        raise ValueError("symbolic_vector: n must be non-negative, got %r" % (n,))
    # Count digits with exact integer arithmetic: math.log10(n - 1) fails with a
    # domain error for n == 0 and can misround for very large n.
    digits = len(str(max(n - 1, 0)))
    rule = "%%s%%0%dd" % digits
    v = swiginac.matrix(n, 1)
    for i in xrange(n):
        v[i] = symbol(rule % (name, i))
    return v

def symbolic_matrix(m, n, name):
    """Returns an m x n matrix of symbols from the default sfc symbol factory.

    Entry (i, j), with i = 0, ..., m-1 and j = 0, ..., n-1, is the symbol named
      'name<i><j>'    if m < 10 and n < 10 (every index is a single digit), or
      'name_<i>_<j>'  if m >= 10 or n >= 10.
    """
    # Once an index can have two or more digits, plain concatenation becomes
    # ambiguous ('A111' could be (1, 11) or (11, 1)), so separate with '_'.
    sep = "_" if (m > 9 or n > 9) else ""
    A = swiginac.matrix(m, n)
    for i in xrange(m):
        for j in xrange(n):
            A[i, j] = symbol("%s%s%d%s%d" % (name, sep, i, sep, j))
    return A

def symbols(names):
    """Map each string in the iterable 'names' to its symbol from the default
    sfc symbol factory, returning the results as a list in the same order."""
    return list(map(symbol, names))

