########################################################################
#
# File Name: 	        SQL.py
#
# Documentation:	http://docs.ftsuite.com/4ODS/Drivers/Postgres.py.html
#
"""
Common backend for all SQL'ish databases
WWW: http://4suite.org/4ODS         e-mail: support@4suite.org

Copyright (c) 1999 Fourthought, Inc, USA.   All Rights Reserved.
See  http://4suite.org/COPYRIGHT  for license and copyright information
"""
import sys,string,types, os

import string, time

from Ft.Ods.StorageManager.Adapters import Constants
from Ft.Ods.StorageManager.Adapters import ClassCache

import CompiliedStatement

class DbAdapter:
    """Common SQL storage backend shared by the concrete database adapters.

    This class implements the object, repository, extent, binding,
    collection, dictionary and operation interfaces on top of compiled
    statements.  Concrete subclasses are expected to supply (TODO confirm
    against the driver subclasses):

    * ``statements`` -- maps ``CompiliedStatement`` ids to statement objects
      exposing ``query(db, **params)`` and ``execute(db, **params)``;
    * ``uptoDateStatements`` -- per persistent type, the statement used by
      ``objectUptoDate``;
    * ``nameMap`` -- attribute-name -> column-name overrides;
    * ``toSqlValues`` / ``toOdmgValues`` -- per-type value converters;
    * ``manager`` -- the storage manager holding the table-name mappings;
    * the transaction, locking, ``_execute`` and blob methods, which raise
      ``NotImplementedError`` here.
    """
    statements = {}
    uptoDateStatements = {}
    nameMap = {}

    ################
    #The Transactional Interface
    ################
    def begin(self, db):
        """Start a transaction on *db*.  Must be overridden."""
        raise NotImplementedError("Override")

    def commit(self, db):
        """Commit the current transaction on *db*.  Must be overridden."""
        raise NotImplementedError("Override")

    def abort(self, db):
        """Roll back the current transaction on *db*.  Must be overridden."""
        raise NotImplementedError("Override")

    def checkpoint(self, db):
        """Checkpoint the current transaction on *db*.  Must be overridden."""
        raise NotImplementedError("Override")

    ############################
    ###Locking
    ############################

    def lock(self, db, repoId, mode):
        """Acquire a lock on *repoId* in *mode*.  Must be overridden."""
        raise NotImplementedError("Override")

    def try_lock(self, db, repoId, mode):
        """Try to lock *repoId* in *mode*.  Must be overridden."""
        raise NotImplementedError("Override")

    #############################
    #MISC
    #############################
    def _execute(self, db, st):
        """Execute the raw SQL string *st* on *db*.  Must be overridden."""
        raise NotImplementedError("Override")

    def objectUptoDate(self, db, persistentType, id_, testTime):
        """Return the row count of the up-to-date query for *id_*.

        The statement registered for *persistentType* in
        ``uptoDateStatements`` is run with *id_* and *testTime*.
        """
        rows = self.uptoDateStatements[persistentType].query(db, id=id_, testTime=testTime)
        return len(rows)

    def _4ods_getModuleName(self, klass):
        """Return the module name under which *klass* is stored in the database.

        Classes defined in the running script report ``__main__`` as their
        module.  That prefix is replaced by the script path (``sys.argv[0]``
        without its extension, ``/`` turned into ``.``), keeping any suffix
        that followed ``__main__``.
        """
        modName = str(klass.__module__)
        if modName[:8] == '__main__':
            scriptBase = os.path.splitext(sys.argv[0])[0]
            modName = scriptBase.replace('/', '.') + modName[8:]
        return modName

    ###########################
    #Python Classes
    ###########################
    def newPythonClass(self, db, klass):
        """Register *klass* in the database and return its new integer id."""
        pid = int(self.statements[CompiliedStatement.NEW_PYTHON_CLASS_ID].query(db)[0][0])
        modName = self._4ods_getModuleName(klass)
        self.statements[CompiliedStatement.NEW_PYTHON_CLASS].execute(
            db, pid=pid, moduleName=modName, className=klass.__name__)
        return pid

    def getPythonClass(self, db, pid):
        """Return the first two columns stored for class id *pid*, or None."""
        rows = self.statements[CompiliedStatement.GET_PYTHON_CLASS].query(db, pid=pid)
        if not rows:
            return None
        return rows[0][0], rows[0][1]

    def deletePythonClass(self, db, pid):
        """Remove the class record with id *pid*."""
        self.statements[CompiliedStatement.DELETE_PYTHON_CLASS].execute(db, pid=pid)

    def getPythonClassId(self, db, klass):
        """Return the database id of *klass*, or None if it is not registered.

        ``ClassCache.g_cache`` is consulted first.  An id found in the
        database is registered in that cache before being returned.
        """
        klassId = ClassCache.g_cache.getClassId(klass)
        if klassId is not None:
            return klassId
        rows = self.statements[CompiliedStatement.GET_PYTHON_CLASS_ID].query(
            db, moduleName=self._4ods_getModuleName(klass), className=klass.__name__)
        if not rows:
            return None
        klassId = int(rows[0][0])
        ClassCache.g_cache.registerClassId(klassId, klass)
        return klassId

    def _4ods_getPythonClassId(self, db, klass):
        """Return the database id of *klass*.

        Raises LookupError if the class has not been registered.
        """
        cid = self.getPythonClassId(db, klass)
        if cid is None:
            # A real exception class: string exceptions cannot be raised on
            # Python >= 2.6 and the message would be lost.
            raise LookupError("Unknown Class %s, did you register it?" % str(klass))
        return cid

    ###########################
    #Python Classes of Literals
    ###########################

    def newPythonLiteralClass(self, db, typ, repoId, klass):
        """Record *klass* as the Python class for literal *typ* / *repoId*."""
        self.statements[CompiliedStatement.NEW_LITERAL_CLASS].execute(
            db, moduleName=self._4ods_getModuleName(klass), className=klass.__name__,
            type=typ, repoid=repoId)

    def modifyPythonLiteralClass(self, db, typ, repoId, klass):
        """Replace the Python class recorded for literal *typ* / *repoId*."""
        self.statements[CompiliedStatement.UPDATE_LITERAL_CLASS].execute(
            db, moduleName=self._4ods_getModuleName(klass), className=klass.__name__,
            type=typ, repoid=repoId)

    def getPythonLiteralClass(self, db, typ, repoId):
        """Return the first row stored for literal *typ* / *repoId*, or None."""
        rows = self.statements[CompiliedStatement.GET_LITERAL_CLASS].query(db, type=typ, repoid=repoId)
        if not rows:
            return None
        return rows[0]

    def deletePythonLiteralClass(self, db, typ, repoId):
        """Remove the class record for literal *typ* / *repoId*."""
        self.statements[CompiliedStatement.DELETE_LITERAL_CLASS].execute(db, type=typ, repoid=repoId)

    #####################
    #Repository Management
    #####################
    def newRepositoryObjectId(self, db, metaKind, klass, rid=None):
        """Create storage rows for a repository object and return its id.

        A new id is allocated when *rid* is None.  Meta kind 2 classes whose
        tuple names contain ``key_type`` are stored in ``ftods_dictionary``.
        """
        if rid is None:
            rid = int(self.statements[CompiliedStatement.NEW_REPO_ID].query(db)[0][0])
        if metaKind == 2 and ('key_type',) in klass._tupleNames:
            tableName = 'ftods_dictionary'
        else:
            tableName = self.manager.metaKindToTableName[metaKind]
        klassId = self._4ods_getPythonClassId(db, klass)
        self.statements[CompiliedStatement.NEW_REPO].execute(db, tableName=tableName, rid=rid)
        self.statements[CompiliedStatement.NEW_REPO_MAP].execute(
            db, metaKind=metaKind, rid=rid, pythonclassid=klassId)
        return rid

    #Repository Object Helpers
    def newInterfaceStorage(self, db, interface):
        """Build the table description for *interface* and let the manager create it."""
        data = self._buildInterfaceTable(interface)
        self.manager._createInterfaceStorage(db, data)

    def writeRepositoryObject(self, db, metaKind, rid, types, names, values):
        """Write the non-None attribute values of repository object *rid*.

        Returns None when nothing needs writing.  Otherwise the row is
        updated, its timestamp refreshed, and 1 is returned.
        """
        if metaKind == 2 and ('key_type',) in names:
            tableName = 'ftods_dictionary'
        else:
            tableName = self.manager.metaKindToTableName[metaKind]

        # NOTE(review): the entries of *names* are 1-tuples, so this lookup only
        # changes them if nameMap has tuple keys; _4ods_createUpdateRepoString
        # applies nameMap again to the bare names.  Kept for compatibility.
        nameMap = self.nameMap
        names = [nameMap.get(name, name) for name in names]
        qstr = self._4ods_createUpdateRepoString(tableName, rid, types, names, values)
        if not qstr:
            return None
        self._execute(db, qstr)
        self.statements[CompiliedStatement.UPDATE_REPO_TIME].execute(db, rid=rid, newTime=time.time())
        return 1

    def getRepositoryObject(self, db, id):
        """Load repository object *id* as ``(klass, values)``.

        Returns None if either the object map entry or its data row is
        missing.  Each value is a 1-tuple holding the converted column.
        """
        rows = self.statements[CompiliedStatement.GET_REPO].query(db, rid=id)
        if not rows:
            return None
        metakind = int(rows[0][0])
        kid = int(rows[0][1])
        klass = ClassCache.g_cache.getClass(self, db, kid)
        if metakind == 2 and ('key_type',) in klass._tupleNames:
            tableName = 'ftods_dictionary'
        else:
            tableName = self.manager.metaKindToTableName[metakind]

        #Now load the tuple
        tupleRows = self.statements[CompiliedStatement.GET_REPO_TUPLE].query(db, tableName=tableName, rid=id)
        if not tupleRows:
            # Consistent with getObject: a missing data row means "not found".
            return None
        # map() over two sequences pads the shorter one with None instead of
        # silently truncating like zip(), so it is kept deliberately.
        values = map(lambda t, v, o=self.toOdmgValues: (o[t[0]](v),), klass._tupleTypes, tupleRows[0])
        return (klass, values)

    def deleteRepositoryObject(self, db, rid, metaKind):
        """Delete repository object *rid*, its map entry and any dictionary row."""
        tableName = self.manager.metaKindToTableName[metaKind]
        self.statements[CompiliedStatement.DELETE_REPO].execute(db, tableName=tableName, rid=rid)
        self.statements[CompiliedStatement.DELETE_REPO_MAP].execute(db, rid=rid)
        if metaKind == 2:
            # Meta kind 2 objects may have been stored in ftods_dictionary instead.
            self.statements[CompiliedStatement.DELETE_REPO].execute(db, tableName='ftods_dictionary', rid=rid)

    #############################
    #Object Interface
    #############################
    def newObjectId(self, db, klass):
        """Allocate an object id, record its class and create its data row."""
        tableName = 'ftods_interface__%d' % klass._typeId
        oid = int(self.statements[CompiliedStatement.NEW_OBJECT_ID].query(db)[0][0])
        klassId = self._4ods_getPythonClassId(db, klass)
        self.statements[CompiliedStatement.NEW_OBJECT_MAP].execute(db, oid=oid, pythonclassid=klassId)
        self.statements[CompiliedStatement.NEW_OBJECT].execute(db, tableName=tableName, oid=oid)
        return oid

    def writeObject(self, db, oid, typeId, types, names, values):
        """Write the non-None attribute values of object *oid*.

        Returns None when nothing needs writing, otherwise 1 after the
        object's timestamp has been refreshed.
        """
        qstr = self._4ods_createUpdateObjectString(oid, typeId, types, names, values)
        if not qstr:
            return None
        self._execute(db, qstr)
        self.statements[CompiliedStatement.UPDATE_OBJECT_TIME].execute(db, oid=oid, newTime=time.time())
        return 1

    def getObject(self, db, oid):
        """Load object *oid* as ``(klass, nestedValues)``, or None if missing."""
        rows = self.statements[CompiliedStatement.GET_OBJECT].query(db, oid=oid)
        if not rows:
            return None
        kid = int(rows[0][0])
        klass = ClassCache.g_cache.getClass(self, db, kid)
        tableName = 'ftods_interface__%d' % klass._typeId

        # Now load the tuple
        results = self.statements[CompiliedStatement.GET_OBJECT_TUPLE].query(db, tableName=tableName, oid=oid)
        if not results:
            return None
        return (klass, self._unflattenValues(results[0], klass._tupleTypes)[0])

    def deleteObject(self, db, oid, klass):
        """Delete the data row and class map entry of object *oid*."""
        tableName = 'ftods_interface__%d' % klass._typeId
        self.statements[CompiliedStatement.DELETE_OBJECT].execute(db, tableName=tableName, oid=oid)
        self.statements[CompiliedStatement.DELETE_OBJECT_MAP].execute(db, oid=oid)

    def getAllObjectIds(self, db):
        """Return a list of every stored object id as ints."""
        rows = self.statements[CompiliedStatement.GET_OBJECT_IDS].query(db)
        return [int(row[0]) for row in rows]

    ###########################
    #Extent interface
    ###########################
    def addExtent(self, db, name, persistentType):
        """Create a new, empty extent called *name* for *persistentType*."""
        # int() for consistency with every other extent-id path in this class.
        extentId = int(self.statements[CompiliedStatement.NEW_EXTENT_ID].query(db)[0][0])
        self.statements[CompiliedStatement.NEW_EXTENT].execute(
            db, name=name, extentId=extentId, persistentType=persistentType)

    def dropExtent(self, db, name):
        """Remove extent *name* together with all of its entries.

        An unknown *name* is not checked for here.
        """
        rows = self.statements[CompiliedStatement.GET_EXTENT_ID].query(db, name=name)
        extentId = int(rows[0][0])
        self.statements[CompiliedStatement.DROP_ALL_EXTENT_MAPPING].execute(db, extentId=extentId)
        self.statements[CompiliedStatement.DROP_EXTENT].execute(db, extentId=extentId)

    def addExtentEntry(self, db, names, oid):
        """Add object *oid* to each extent listed in *names*."""
        for name in names:
            rows = self.statements[CompiliedStatement.GET_EXTENT_ID].query(db, name=name)
            self.statements[CompiliedStatement.ADD_EXTENT_MAPPING].execute(db, extentId=int(rows[0][0]), id=oid)

    def dropExtentEntry(self, db, names, oid):
        """Remove object *oid* from each extent listed in *names*."""
        for name in names:
            rows = self.statements[CompiliedStatement.GET_EXTENT_ID].query(db, name=name)
            self.statements[CompiliedStatement.DROP_EXTENT_MAPPING].execute(db, extentId=int(rows[0][0]), id=oid)

    def extentNames(self, db):
        """Return a list of all extent names."""
        rows = self.statements[CompiliedStatement.GET_EXTENT_NAMES].query(db)
        return [row[0] for row in rows]

    def loadExtent(self, db, name):
        """Return ``(persistentType, [member ids])`` for extent *name*, or None."""
        rows = self.statements[CompiliedStatement.GET_EXTENT_ID_AND_TYPE].query(db, name=name)
        if not rows:
            return None
        members = self.statements[CompiliedStatement.GET_EXTENT].query(db, extentId=int(rows[0][0]))
        return (int(rows[0][1]), [int(member[0]) for member in members])

    ###########################
    #Binding interface
    ###########################
    def bind(self, db, oid, name):
        """Bind *name* to object *oid*."""
        self.statements[CompiliedStatement.ADD_BOUND_NAME].execute(db, name=name, oid=oid)

    def unbind(self, db, name):
        """Remove the binding for *name*."""
        self.statements[CompiliedStatement.DROP_BOUND_NAME].execute(db, name=name)

    def getAllBindings(self, db):
        """Return a list of every bound name."""
        rows = self.statements[CompiliedStatement.GET_BOUND_NAMES].query(db)
        return [row[0] for row in rows]

    def getObjectBindings(self, db, oid):
        """Return a list of the names bound to object *oid*."""
        rows = self.statements[CompiliedStatement.GET_OBJECT_BOUND_NAMES].query(db, oid=oid)
        return [row[0] for row in rows]

    def lookup(self, db, name):
        """Return the object id bound to *name*, or None if it is unbound."""
        rows = self.statements[CompiliedStatement.GET_BOUND_OBJECT].query(db, name=name)
        if rows:
            return int(rows[0][0])
        return None

    ############################
    #Collection Interface
    ###########################
    def newCollection(self, db, klass, collType, subType, subTypeRepoId):
        """Allocate a collection id, record its class and element type, return the id."""
        cid = int(self.statements[CompiliedStatement.NEW_COLLECTION_ID].query(db)[0][0])
        klassId = self._4ods_getPythonClassId(db, klass)
        self.statements[CompiliedStatement.NEW_COLLECTION].execute(
            db, cid=cid, pythonclassid=klassId, subType=subType, subTypeRepoId=subTypeRepoId)
        return cid

    def writeCollection(self, db, cid, changes, collType, subType):
        """Apply the recorded *changes* to collection *cid* and return 1.

        Each change is ``(kind, [(value, index), ...])``.  INSERT shifts the
        following indexes up before adding the entry, and APPEND just adds it.
        Any other kind is treated as a removal, which deletes the entry and
        shifts the following indexes down.
        """
        tableName = self.manager.collectionTableMapping[subType][0]
        statements = self.statements
        for change in changes:
            kind = change[0]
            if kind in [Constants.CollectionChanges.INSERT, Constants.CollectionChanges.APPEND]:
                for value, index in change[1]:
                    value = self.toSqlValues[subType](value)
                    if kind == Constants.CollectionChanges.INSERT:
                        #Shift all other indexes
                        statements[CompiliedStatement.UPDATE_COLLECTION_INDEXES_ADD].execute(
                            db, cid=cid, tableName=tableName, index=index)
                    statements[CompiliedStatement.NEW_COLLECTION_ENTRY].execute(
                        db, cid=cid, tableName=tableName, index=index, value=value)
            else:
                #Must be a remove
                for value, index in change[1]:
                    statements[CompiliedStatement.DELETE_COLLECTION_ENTRY].execute(
                        db, cid=cid, tableName=tableName, index=index)
                    statements[CompiliedStatement.UPDATE_COLLECTION_INDEXES_REMOVE].execute(
                        db, cid=cid, tableName=tableName, index=index)
        statements[CompiliedStatement.UPDATE_COLLECTION_TIME].execute(db, cid=cid, newTime=time.time())
        return 1

    def getCollection(self, db, cid):
        """Load collection *cid*.

        Returns ``(klass, subType, subTypeRepoId, {position: value})``, or
        None if the collection does not exist.
        """
        rows = self.statements[CompiliedStatement.GET_COLLECTION].query(db, cid=cid)
        if not rows:
            return None
        subType = int(rows[0][0])
        subTypeRepoId = int(rows[0][1])
        kid = int(rows[0][2])
        klass = ClassCache.g_cache.getClass(self, db, kid)
        tableName = self.manager.collectionTableMapping[subType][0]
        entries = self.statements[CompiliedStatement.GET_COLLECTION_TUPLE].query(db, cid=cid, tableName=tableName)
        convert = self.toOdmgValues[subType]
        values = [convert(entry[0]) for entry in entries]
        res = {}
        for position in range(len(values)):
            res[position] = values[position]
        return (klass, subType, subTypeRepoId, res)

    def deleteCollection(self, db, cid, collType, subType):
        """Delete collection *cid* and all of its entries."""
        tableName = self.manager.collectionTableMapping[subType][0]
        self.statements[CompiliedStatement.DELETE_COLLECTION].execute(db, cid=cid)
        self.statements[CompiliedStatement.DELETE_COLLECTION_TUPLE].execute(db, tableName=tableName, cid=cid)

    ############################
    #Dictionary Interface
    ###########################
    def newDictionary(self, db, klass, keyType, keyTypeRepoId, subType, subTypeRepoId):
        """Allocate a dictionary id, record its class and key/value types, return the id."""
        did = int(self.statements[CompiliedStatement.NEW_DICTIONARY_ID].query(db)[0][0])
        klassId = self._4ods_getPythonClassId(db, klass)
        self.statements[CompiliedStatement.NEW_DICTIONARY].execute(
            db, did=did, pythonclassid=klassId, keyType=keyType, subType=subType,
            keytyperepoid=keyTypeRepoId, subtyperepoid=subTypeRepoId)
        return did

    def _4ods_encodeDictionaryItem(self, item):
        """Encode a dictionary key or value as the text stored in the database.

        Plain strings are wrapped in double quotes and tuples are repr()'d.
        Both are then passed through the STRING SQL converter with its first
        and last characters stripped.  Anything else is stored as repr().
        getDictionary eval()s this text back.
        """
        # NOTE(review): '"%s"' % item does not escape double quotes or
        # backslashes inside item, so such strings do not eval() back to the
        # original.  Changing the encoding would alter the stored text that
        # existing rows are matched against, so it is left as-is.
        if type(item) == type(''):
            return self.toSqlValues[Constants.Types.STRING]('"%s"' % item)[1:-1]
        if type(item) == type(()):
            return self.toSqlValues[Constants.Types.STRING](repr(item))[1:-1]
        return repr(item)

    def writeDictionary(self, db, did, changes, keyType, subType):
        """Apply ``(change, key, value)`` triples to dictionary *did* and return 1.

        ADD inserts an entry and CHANGE updates one.  Any other change kind
        removes the entry for *key*.  A None value is passed through
        unencoded.
        """
        statements = self.statements
        for change, key, value in changes:
            key = self._4ods_encodeDictionaryItem(key)
            if value is not None:
                value = self._4ods_encodeDictionaryItem(value)

            if change == Constants.DictionaryChanges.ADD:
                statements[CompiliedStatement.NEW_DICTIONARY_ENTRY].execute(db, did=did, key=key, value=value)
            elif change == Constants.DictionaryChanges.CHANGE:
                statements[CompiliedStatement.UPDATE_DICTIONARY_ENTRY].execute(db, did=did, key=key, value=value)
            else:
                #Must be a remove
                statements[CompiliedStatement.DELETE_DICTIONARY_ENTRY].execute(db, did=did, key=key)

        statements[CompiliedStatement.UPDATE_DICTIONARY_TIME].execute(db, did=did, newTime=time.time())
        return 1

    def getDictionary(self, db, did):
        """Load dictionary *did*.

        Returns ``(klass, keyType, keyTypeRepoId, subType, subTypeRepoId,
        contents)``, or None if the dictionary does not exist.
        """
        rows = self.statements[CompiliedStatement.GET_DICTIONARY].query(db, did=did)
        if not rows:
            return None
        keyType = int(rows[0][0])
        keyTypeRepoId = int(rows[0][1])
        subType = int(rows[0][2])
        subTypeRepoId = int(rows[0][3])
        kid = int(rows[0][4])
        klass = ClassCache.g_cache.getClass(self, db, kid)
        entries = self.statements[CompiliedStatement.GET_DICTIONARY_TUPLE].query(db, did=did)
        res = {}
        for name, value in entries:
            # SECURITY NOTE(review): eval() runs whatever text is stored in the
            # dictionary table, so anyone able to write that table can execute
            # code here.  ast.literal_eval would be safer if every stored item
            # is a plain literal -- confirm before switching.
            res[eval(name)] = eval(value)
        return (klass, keyType, keyTypeRepoId, subType, subTypeRepoId, res)

    def deleteDictionary(self, db, did, keyType, subType):
        """Delete dictionary *did* and all of its entries."""
        self.statements[CompiliedStatement.DELETE_DICTIONARY].execute(db, did=did)
        self.statements[CompiliedStatement.DELETE_DICTIONARY_TUPLE].execute(db, did=did)

    #######################################
    #Blobs
    #######################################
    def newBlob(self, db):
        """Allocate a blob and return its id.  Must be overridden."""
        raise NotImplementedError("Please Override")

    def readBlob(self, db, bid):
        """Return the contents of blob *bid*.  Must be overridden."""
        raise NotImplementedError("Please Override")

    def writeBlob(self, db, oid, data):
        """Store *data* in blob *oid*.  Must be overridden."""
        raise NotImplementedError("Please Override")

    def deleteBlob(self, db, oid):
        """Delete blob *oid*.  Must be overridden."""
        raise NotImplementedError("Please Override")

    #############################
    #Literal Interface
    #############################

    #Garbage Collection
    def collectGarbage(self):
        """Run repository-object garbage collection.

        ``_4ods_collectRepositoryObjectGarbage`` is not defined in this class;
        presumably a subclass provides it.
        """
        self._4ods_collectRepositoryObjectGarbage()

    ##########################
    #Internal Interfaces
    ##########################
    def _4ods_createUpdateRepoString(self, tableName, rid, types, names, values):
        """Build an ``UPDATE`` statement for repository row *rid*, or return None.

        *names* and *types* hold 1-tuples.  The ``_oid`` column is skipped,
        column names go through ``nameMap``, and entries whose value is None
        are omitted.  Returns None when no column remains to be set.
        """
        assignments = []
        for ctr in range(len(names)):
            name = names[ctr][0]
            if name == '_oid':
                continue
            name = self.nameMap.get(name, name)
            entry = values[ctr]
            # A two-slot entry stores its second slot; otherwise its first.
            if len(entry) == 2:
                v = entry[1]
            else:
                v = entry[0]
            if v is not None:
                assignments.append(name + '=' + self.toSqlValues[types[ctr][0]](v))

        if not assignments:
            return None
        return 'UPDATE %s SET %s WHERE _repoid = %i' % (tableName, ','.join(assignments), rid)

    def _4ods_createUpdateObjectString(self, oid, typeId, types, names, values):
        """Build an ``UPDATE`` statement for object *oid*, or return None.

        Nested names, types and values are flattened first.  None values and
        the ``_oid`` column are skipped.  Returns None when no column remains
        to be set.
        """
        names, types, values = self._flattenNamesTypesAndValues(names, types, values)

        assignments = []
        for ctr in range(len(names)):
            if values[ctr] is not None and names[ctr] != '_oid':
                assignments.append(names[ctr] + '=' + self.toSqlValues[types[ctr]](values[ctr]))

        if not assignments:
            return None
        tableName = 'ftods_interface__%i' % typeId
        return 'UPDATE %s SET %s WHERE _oid = %i' % (tableName, ','.join(assignments), oid)

    def _buildInterfaceTable(self, interface, foundNames=None):
        """Describe the table for *interface* as ``(tableName, defs, bases)``.

        *bases* holds the same kind of description for each inherited
        interface (and the extender, for meta kind 1) not already in
        *foundNames*.  It falls back to ``ftods_object`` when there are none.
        Returns None if *interface*'s table was already described.
        *foundNames* is an accumulator shared through the recursion.
        """
        if foundNames is None:
            foundNames = []
        tableName = "ftods_interface__%i" % interface._4ods_getId()
        if tableName in foundNames:
            return None
        foundNames.append(tableName)

        from Ft.Ods.Parsers.Odl import OdlUtil
        names, types = OdlUtil.BuildNameAndTypeList(interface.defines)
        names, types = self._flattenNamesAndTypes(names, types)

        # The flattened tuples are built in lockstep, so zip() pairs them exactly.
        odmgToSql = self.manager.odmgToSqlTypes
        defs = []
        for name, t in zip(names, types):
            defs.append((name, odmgToSql[t._4ods_getOdmgType()](t)))

        bClasses = []
        for inh in interface.inherits:
            rt = self._buildInterfaceTable(inh, foundNames)
            if rt:
                bClasses.append(rt)
        if interface.meta_kind._v == 1 and interface.extender:
            rt = self._buildInterfaceTable(interface.extender, foundNames)
            if rt:
                bClasses.append(rt)
        if not bClasses and 'ftods_object' not in foundNames:
            systemTypesToSql = self.manager.systemTypesToSql
            initData = self.manager.tables['ftods_object'][self.manager.TableData.TABLE_INIT]
            data = [(column[0], systemTypesToSql[column[1]]) for column in initData]
            bClasses.append(('ftods_object', data, []))
            foundNames.append('ftods_object')
        return (tableName, defs, bClasses)

    def _flattenNamesAndTypes(self, names, types, parentName=''):
        """Flatten nested name/type tuples into two parallel tuples.

        Each entry is itself a tuple.  A 1-tuple type is a leaf, and a
        2-tuple type is nested, with its children in slot 1.  Nested names are
        prefixed with ``parent_``.
        """
        newNames = []
        newTypes = []
        for ctr in range(len(names)):
            name = names[ctr]
            t = types[ctr]
            if len(t) == 2:
                #A nested definition
                nn, nt = self._flattenNamesAndTypes(name[1], t[1], parentName=parentName + name[0] + '_')
                newNames.extend(list(nn))
                newTypes.extend(list(nt))
            else:
                newNames.append(parentName + name[0])
                newTypes.append(t[0])
        return tuple(newNames), tuple(newTypes)

    def _flattenNamesTypesAndValues(self, names, types, values, parentName=''):
        """Flatten nested names, types and values into three parallel tuples.

        A None value prunes its whole branch.  A 1-tuple type is a leaf, and a
        2-tuple type is nested, with its children in slot 1 of the name, type
        and value.  Nested names are prefixed with ``parent_``.
        """
        newNames = []
        newTypes = []
        newValues = []
        for ctr in range(len(values)):
            name = names[ctr]
            t = types[ctr]
            v = values[ctr]
            if v is None:
                continue
            if len(t) == 2:
                #A nested definition
                nn, nt, nv = self._flattenNamesTypesAndValues(
                    name[1], t[1], v[1], parentName=parentName + name[0] + '_')
                newNames.extend(list(nn))
                newTypes.extend(list(nt))
                newValues.extend(list(nv))
            else:
                newNames.append(parentName + name[0])
                newTypes.append(t[0])
                newValues.append(v[0])
        return tuple(newNames), tuple(newTypes), tuple(newValues)

    def _unflattenValues(self, values, types):
        """Rebuild nested values from a flat row, following *types*.

        Returns ``(nestedValues, consumed)``, where *consumed* is the number
        of entries of *values* used.  Leaves become converted 1-tuples, or None
        for a None column.  Nested entries become ``(None, children)``.
        """
        newValues = []
        vCtr = 0
        for t in types:
            if len(t) == 1:
                # a single value
                if values[vCtr] is not None:
                    newValues.append((self.toOdmgValues[t[0]](values[vCtr]),))
                else:
                    newValues.append(None)
                vCtr = vCtr + 1
            else:
                #a nested literal
                nv, used = self._unflattenValues(values[vCtr:], t[1])
                newValues.append((None, nv))
                vCtr = vCtr + used
        return tuple(newValues), vCtr

    ####################################
    #
    # Operations
    #
    ####################################

    def registerOperation(self, db, repoId, functionName, functionModule):
        """Store *functionModule* in a blob and record it for operation *repoId*.

        An existing record has its blob overwritten and its function name
        updated.  Otherwise a new blob and record are created.
        """
        res = self.statements[CompiliedStatement.GET_OPERATION].query(db, repoId=repoId)
        if res:
            #This is an update
            blobId = int(res[0][1])
            self.writeBlob(db, blobId, functionModule)
            self.statements[CompiliedStatement.UPDATE_OPERATION].execute(
                db, functionName=functionName, blobId=blobId, repoId=repoId)
        else:
            #This is a new one
            blobId = self.newBlob(db)
            self.writeBlob(db, blobId, functionModule)
            self.statements[CompiliedStatement.INSERT_OPERATION].execute(
                db, functionName=functionName, blobId=blobId, repoId=repoId)

    def getOperation(self, db, repoId):
        """Return ``(functionName, blobContents)`` for operation *repoId*, or None."""
        res = self.statements[CompiliedStatement.GET_OPERATION].query(db, repoId=repoId)
        if not res:
            return None
        functionName, blobId = res[0]
        functionModule = self.readBlob(db, int(blobId))
        return functionName, functionModule
        
