module Database.HSQL.ODBC(connect, driverConnect, module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Data.Word(Word32, Word16)
import Data.Int(Int32, Int16)
import Data.Maybe
import Foreign
import Foreign.C
import Control.Monad(unless)
import Control.OldException(throwDyn)
import Control.Concurrent.MVar
import System.IO.Unsafe
import System.Time
-- Fixed-width integer aliases named after the corresponding ODBC C typedefs.
type SQLSMALLINT  = Int16
type SQLUSMALLINT = Word16
type SQLINTEGER   = Int32
type SQLUINTEGER  = Word32

-- Every ODBC call in this module reports its status as a SQLRETURN.
type SQLRETURN    = SQLSMALLINT

-- Length / indicator types.  Both are aliases for the signed 32-bit integer.
-- NOTE(review): ODBC headers on 64-bit platforms usually declare SQLLEN and
-- SQLULEN as 64-bit values -- confirm these aliases match the target headers.
type SQLLEN       = SQLINTEGER
type SQLULEN      = SQLINTEGER

-- Untyped handle shared by environment, connection and statement handles.
type SQLHANDLE    = Ptr ()
type HENV         = SQLHANDLE
type HDBC         = SQLHANDLE
type HSTMT        = SQLHANDLE

-- Environment handle wrapped in a ForeignPtr so a finalizer can release it.
type HENVRef      = ForeignPtr ()
-- Raw bindings to the ODBC C API, imported through the HsODBC.h header.
-- Every function returns a SQLRETURN status code that callers pass to
-- 'handleSqlResult'.

-- Environment handles --------------------------------------------------------

foreign import ccall "HsODBC.h SQLAllocEnv"
  sqlAllocEnv
    :: Ptr HENV
    -> IO SQLRETURN

-- Imported as an address (note the '&') so that it can be installed as a
-- ForeignPtr finalizer.
foreign import ccall "HsODBC.h &SQLFreeEnv"
  sqlFreeEnv_p
    :: FunPtr (HENV -> IO ())

-- Connection handles ---------------------------------------------------------

foreign import ccall "HsODBC.h SQLAllocConnect"
  sqlAllocConnect
    :: HENV
    -> Ptr HDBC
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLFreeConnect"
  sqlFreeConnect
    :: HDBC
    -> IO SQLRETURN

-- Each CString argument is followed by its length argument.
foreign import ccall "HsODBC.h SQLConnect"
  sqlConnect
    :: HDBC
    -> CString -> Int
    -> CString -> Int
    -> CString -> Int
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLDriverConnect"
  sqlDriverConnect
    :: HDBC
    -> Ptr ()
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> SQLUSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLDisconnect"
  sqlDisconnect
    :: HDBC
    -> IO SQLRETURN

-- Statement handles and result-set metadata ----------------------------------

foreign import ccall "HsODBC.h SQLAllocStmt"
  sqlAllocStmt
    :: HDBC
    -> Ptr HSTMT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLFreeStmt"
  sqlFreeStmt
    :: HSTMT
    -> SQLUSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLNumResultCols"
  sqlNumResultCols
    :: HSTMT
    -> Ptr SQLUSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLDescribeCol"
  sqlDescribeCol
    :: HSTMT
    -> SQLUSMALLINT
    -> CString -> SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> Ptr SQLULEN
    -> Ptr SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLBindCol"
  sqlBindCol
    :: HSTMT
    -> SQLUSMALLINT
    -> SQLSMALLINT
    -> Ptr a
    -> SQLLEN
    -> Ptr SQLINTEGER
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLFetch"
  sqlFetch
    :: HSTMT
    -> IO SQLRETURN

-- Diagnostics ----------------------------------------------------------------

foreign import ccall "HsODBC.h SQLGetDiagRec"
  sqlGetDiagRec
    :: SQLSMALLINT
    -> SQLHANDLE
    -> SQLSMALLINT
    -> CString
    -> Ptr SQLINTEGER
    -> CString -> SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> IO SQLRETURN

-- Execution, transactions and data retrieval ---------------------------------

foreign import ccall "HsODBC.h SQLExecDirect"
  sqlExecDirect
    :: HSTMT
    -> CString -> Int
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLSetConnectOption"
  sqlSetConnectOption
    :: HDBC
    -> SQLUSMALLINT
    -> SQLULEN
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLTransact"
  sqlTransact
    :: HENV
    -> HDBC
    -> SQLUSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLGetData"
  sqlGetData
    :: HSTMT
    -> SQLUSMALLINT
    -> SQLSMALLINT
    -> Ptr ()
    -> SQLINTEGER
    -> Ptr SQLINTEGER
    -> IO SQLRETURN

-- Catalog queries ------------------------------------------------------------

foreign import ccall "HsODBC.h SQLTables"
  sqlTables
    :: HSTMT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLColumns"
  sqlColumns
    :: HSTMT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> CString -> SQLSMALLINT
    -> IO SQLRETURN

foreign import ccall "HsODBC.h SQLMoreResults"
  sqlMoreResults
    :: HSTMT
    -> IO SQLRETURN
-- | Turn an ODBC return code into a normal return or a thrown exception.
--
-- Arguments: the handle-type tag and the handle that produced the code (both
-- are only used to fetch the diagnostic record), and the return code itself.
--
-- Return codes handled (values as defined by the ODBC headers):
--
-- *   0  SQL_SUCCESS, 100 SQL_NO_DATA   -> returns normally
-- *   1  SQL_SUCCESS_WITH_INFO          -> returns normally (info is ignored)
-- *  -2  SQL_INVALID_HANDLE             -> throws 'SqlInvalidHandle'
-- *   2  SQL_STILL_EXECUTING            -> throws 'SqlStillExecuting'
-- *  99  SQL_NEED_DATA                  -> throws 'SqlNeedData'
-- *  -1  SQL_ERROR                      -> throws the first diagnostic record
--
-- Any other code aborts with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
  | res == 0 || res == 100 = return ()
  | res == 1               = return ()
  -- The two negative codes below had lost their sign, which made the
  -- SQL_ERROR branch unreachable and mapped SQL_STILL_EXECUTING (2) onto
  -- SqlInvalidHandle.
  | res == (-2)            = throwDyn SqlInvalidHandle
  | res == 2               = throwDyn SqlStillExecuting
  | res == 99              = throwDyn SqlNeedData
  | res == (-1)            = getSqlError >>= throwDyn
  | otherwise              = error (show res)
  where
    -- Read diagnostic record number 1 for the handle into temporary buffers
    -- (256 bytes each for the SQLSTATE and the message text).
    getSqlError =
      allocaBytes 256 $ \pState ->
      alloca $ \pNative ->
      allocaBytes 256 $ \pMsg ->
      alloca $ \pTextLen -> do
        diagRes <- sqlGetDiagRec handleType handle 1
                     pState pNative pMsg 256 pTextLen
        if diagRes == 100   -- SQL_NO_DATA: no diagnostic record available
          then return SqlNoData
          else do
            state  <- peekCString pState
            native <- peek pNative
            msg    <- peekCString pMsg
            return (SqlError { seState       = state
                             , seNativeError = fromIntegral native
                             , seErrorMsg    = msg })
-- | The process-wide ODBC environment handle, allocated on first use and
-- released by the 'sqlFreeEnv_p' finalizer attached to the 'ForeignPtr'.
--
-- The NOINLINE pragma is required: without it the compiler may inline the
-- 'unsafePerformIO' expression at each use site and allocate more than one
-- environment.
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO $ alloca $ \phEnv -> do
  res  <- sqlAllocEnv phEnv
  hEnv <- peek phEnv
  -- No environment handle exists yet to query diagnostics from, so the
  -- result is checked with handle type 0 and a null handle.
  handleSqlResult 0 nullPtr res
  newForeignPtr sqlFreeEnv_p hEnv
-- | Open a connection to a data source.
--
-- Arguments, in order: the data source (server) name, the user name, and the
-- authentication string (password).  Errors reported by ODBC are thrown by
-- 'connectHelper' via 'handleSqlResult'.
connect :: String
        -> String
        -> String
        -> IO Connection
connect server user authentication = connectHelper $ \hDBC ->
  withCString server $ \pServer ->
  withCString user $ \pUser ->
  withCString authentication $ \pAuthentication ->
    sqlConnect hDBC
      pServer         sqlNTS
      pUser           sqlNTS
      pAuthentication sqlNTS
  where
    -- SQL_NTS ("null-terminated string") is -3.  The sign had been lost,
    -- which made the driver read only the first three bytes of each string.
    sqlNTS :: Int
    sqlNTS = -3
-- | Open a connection from a full ODBC connection string.
--
-- The driver is asked not to prompt (last argument 0, SQL_DRIVER_NOPROMPT)
-- and no parent window is supplied ('nullPtr').  The completed connection
-- string that the driver writes into the 1024-byte output buffer is
-- discarded.
driverConnect :: String
              -> IO Connection
driverConnect connString = connectHelper $ \hDBC ->
  withCString connString $ \pConnString ->
  allocaBytes 1024 $ \pOutConnString ->
  alloca $ \pLen ->
    -- (-3) is SQL_NTS: the input string is null-terminated.  The sign had
    -- been lost, truncating the connection string to three bytes.
    sqlDriverConnect hDBC nullPtr
      pConnString (-3)
      pOutConnString 1024 pLen
      0
-- | Allocate a connection handle, run the supplied connect action on it and
-- wrap the handle in a 'Connection' record.
--
-- The argument performs the actual ODBC connect call and returns its status
-- code.  Any failing ODBC call is turned into an exception by
-- 'handleSqlResult'.
--
-- All operations stored in the returned record are defined in the @where@
-- clause below; they must stay local because several of them shadow names
-- exported by "Database.HSQL".
connectHelper :: (HDBC -> IO SQLRETURN) -> IO Connection
connectHelper connectFunction =
  withForeignPtr myEnvironment $ \hEnv -> do
    hDBC <- alloca $ \phDBC -> do
      res <- sqlAllocConnect hEnv phDBC
      handleSqlResult 1 hEnv res          -- 1 = SQL_HANDLE_ENV
      peek phDBC
    res <- connectFunction hDBC
    checkDbc hDBC res
    refFalse <- newMVar False
    let connection = Connection
          { connDisconnect          = disconnect hDBC
          , connExecute             = execute hDBC
          , connQuery               = query connection hDBC
          , connTables              = tables connection hDBC
          , connDescribe            = describe connection hDBC
          , connBeginTransaction    = beginTransaction myEnvironment hDBC
          , connCommitTransaction   = commitTransaction myEnvironment hDBC
          , connRollbackTransaction = rollbackTransaction myEnvironment hDBC
          , connClosed              = refFalse
          }
    return connection
  where
    -- Result checks specialised to connection (2 = SQL_HANDLE_DBC) and
    -- statement (3 = SQL_HANDLE_STMT) handles.
    checkDbc :: HDBC -> SQLRETURN -> IO ()
    checkDbc = handleSqlResult 2

    checkStmt :: HSTMT -> SQLRETURN -> IO ()
    checkStmt = handleSqlResult 3

    -- Disconnect from the data source, then free the connection handle.
    disconnect :: HDBC -> IO ()
    disconnect hDBC = do
      sqlDisconnect hDBC >>= checkDbc hDBC
      sqlFreeConnect hDBC >>= checkDbc hDBC

    -- Run a statement that yields no result set and drop its handle.
    -- 'alloca' sizes the temporary slot from the handle's Storable instance
    -- instead of a hard-coded byte count.
    execute :: HDBC -> String -> IO ()
    execute hDBC sqlText = alloca $ \pStmt -> do
      sqlAllocStmt hDBC pStmt >>= checkDbc hDBC
      hSTMT <- peek pStmt
      withCStringLen sqlText $ \(pQuery, len) -> do
        sqlExecDirect hSTMT pQuery len >>= checkStmt hSTMT
        sqlFreeStmt hSTMT 1 >>= checkStmt hSTMT       -- 1 = SQL_DROP

    -- Size in bytes of the per-statement buffer used to fetch column data.
    stmtBufferSize :: Int
    stmtBufferSize = 256

    -- Allocate a statement handle, run 'f' on it (the call that produces the
    -- result set), read the column descriptions and build a 'Statement'.
    --
    -- The 288-byte scratch area and the byte offsets used below are literal
    -- constants in this file.
    -- NOTE(review): they presumably mirror a C struct declared in HsODBC.h
    -- (hsc2hs output) -- confirm against the header.
    withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
    withStatement connection hDBC f =
      allocaBytes 288 $ \pFIELD -> do
        sqlAllocStmt hDBC (pFIELD `plusPtr` 0) >>= checkDbc hDBC
        hSTMT <- peekByteOff pFIELD 0
        f hSTMT >>= checkStmt hSTMT
        fields <- moveToFirstResult hSTMT pFIELD
        buffer <- mallocBytes stmtBufferSize
        refFalse <- newMVar False
        return Statement
          { stmtConn   = connection
          , stmtClose  = closeStatement hSTMT buffer
          , stmtFetch  = fetch hSTMT
          , stmtGetCol = getColValue hSTMT buffer
          , stmtFields = fields
          , stmtClosed = refFalse
          }
      where
        -- Skip over results that have no columns until one with columns is
        -- found (its field definitions are returned) or no more results
        -- remain (an empty list is returned).
        moveToFirstResult :: HSTMT -> Ptr a -> IO [FieldDef]
        moveToFirstResult hSTMT pFIELD = do
          sqlNumResultCols hSTMT (pFIELD `plusPtr` 8) >>= checkStmt hSTMT
          count <- peekByteOff pFIELD 8
          if count == 0
            then do
              res <- sqlMoreResults hSTMT
              checkStmt hSTMT res
              if res == 100                 -- SQL_NO_DATA
                then return []
                else moveToFirstResult hSTMT pFIELD
            else getFieldDefs hSTMT pFIELD 1 count

        -- Describe columns n..count (1-based) of the current result set.
        getFieldDefs :: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [FieldDef]
        getFieldDefs hSTMT pFIELD n count
          | n > count = return []
          | otherwise = do
              res <- sqlDescribeCol hSTMT n
                       (pFIELD `plusPtr` 10) 255   -- column name buffer
                       (pFIELD `plusPtr` 266)      -- name length
                       (pFIELD `plusPtr` 268)      -- data type
                       (pFIELD `plusPtr` 272)      -- column size
                       (pFIELD `plusPtr` 280)      -- decimal digits
                       (pFIELD `plusPtr` 282)      -- nullable
              checkStmt hSTMT res
              name          <- peekCString (pFIELD `plusPtr` 10)
              dataType      <- peekByteOff pFIELD 268
              columnSize    <- peekByteOff pFIELD 272
              decimalDigits <- peekByteOff pFIELD 280
              nullable      <- peekByteOff pFIELD 282 :: IO SQLSMALLINT
              let sqlType = mkSqlType dataType columnSize decimalDigits
              fields <- getFieldDefs hSTMT pFIELD (n + 1) count
              return ((name, sqlType, toBool nullable) : fields)

    -- Map an ODBC SQL type code (plus column size and decimal digits) to a
    -- 'SqlType'.  Unrecognised codes become 'SqlUnknown'.
    --
    -- Several ODBC type codes are negative; their signs had been lost, which
    -- left duplicate (unreachable) patterns.  The values below are the ones
    -- defined by the ODBC headers.
    mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> SqlType
    mkSqlType 1     size _    = SqlChar (fromIntegral size)              -- SQL_CHAR
    mkSqlType 12    size _    = SqlVarChar (fromIntegral size)           -- SQL_VARCHAR
    mkSqlType (-1)  size _    = SqlLongVarChar (fromIntegral size)       -- SQL_LONGVARCHAR
    mkSqlType 3     size prec = SqlDecimal (fromIntegral size) (fromIntegral prec) -- SQL_DECIMAL
    mkSqlType 2     size prec = SqlNumeric (fromIntegral size) (fromIntegral prec) -- SQL_NUMERIC
    mkSqlType 5     _    _    = SqlSmallInt                              -- SQL_SMALLINT
    mkSqlType 4     _    _    = SqlInteger                               -- SQL_INTEGER
    mkSqlType 7     _    _    = SqlReal                                  -- SQL_REAL
    mkSqlType 6     _    _    = SqlFloat                                 -- SQL_FLOAT
    mkSqlType 8     _    _    = SqlDouble                                -- SQL_DOUBLE
    mkSqlType (-7)  _    _    = SqlBit                                   -- SQL_BIT
    mkSqlType (-6)  _    _    = SqlTinyInt                               -- SQL_TINYINT
    mkSqlType (-5)  _    _    = SqlBigInt                                -- SQL_BIGINT
    mkSqlType (-2)  size _    = SqlBinary (fromIntegral size)            -- SQL_BINARY
    mkSqlType (-3)  size _    = SqlVarBinary (fromIntegral size)         -- SQL_VARBINARY
    mkSqlType (-4)  size _    = SqlLongVarBinary (fromIntegral size)     -- SQL_LONGVARBINARY
    mkSqlType 9     _    _    = SqlDate                                  -- SQL_DATE
    mkSqlType 10    _    _    = SqlTime                                  -- SQL_TIME
    mkSqlType 11    _    _    = SqlDateTime                              -- SQL_TIMESTAMP
    mkSqlType (-8)  size _    = SqlWChar (fromIntegral size)             -- SQL_WCHAR
    mkSqlType (-9)  size _    = SqlWVarChar (fromIntegral size)          -- SQL_WVARCHAR
    mkSqlType (-10) size _    = SqlWLongVarChar (fromIntegral size)      -- SQL_WLONGVARCHAR
    mkSqlType tp    _    _    = SqlUnknown (fromIntegral tp)

    -- Execute a query and return a 'Statement' over its first result set.
    query :: Connection -> HDBC -> String -> IO Statement
    query connection hDBC q = withStatement connection hDBC $ \hSTMT ->
      withCStringLen q (uncurry (sqlExecDirect hSTMT))

    -- Transactions are modelled by toggling connection option 102
    -- (SQL_AUTOCOMMIT): 0 switches autocommit off, 1 switches it back on.
    -- NOTE(review): the ODBC return codes of these calls are discarded, so a
    -- failed commit or rollback goes unreported; kept as in the original.
    beginTransaction _ hDBC = do
      _ <- sqlSetConnectOption hDBC 102 0
      return ()

    commitTransaction envRef hDBC = withForeignPtr envRef $ \hEnv -> do
      _ <- sqlTransact hEnv hDBC 0                   -- 0 = SQL_COMMIT
      _ <- sqlSetConnectOption hDBC 102 1
      return ()

    rollbackTransaction envRef hDBC = withForeignPtr envRef $ \hEnv -> do
      _ <- sqlTransact hEnv hDBC 1                   -- 1 = SQL_ROLLBACK
      _ <- sqlSetConnectOption hDBC 102 1
      return ()

    -- List the TABLE_NAME column of an unfiltered SQLTables call.
    tables :: Connection -> HDBC -> IO [String]
    tables connection hDBC = do
      stmt <- withStatement connection hDBC $ \hSTMT ->
        sqlTables hSTMT nullPtr 0 nullPtr 0 nullPtr 0 nullPtr 0
      collectRows (\s -> getFieldValue s "TABLE_NAME") stmt

    -- Describe the columns of a table via SQLColumns.
    describe :: Connection -> HDBC -> String -> IO [FieldDef]
    describe connection hDBC table = do
      stmt <- withStatement connection hDBC sqlColumns'
      collectRows getColumnInfo stmt
      where
        sqlColumns' hSTMT =
          withCStringLen table $ \(pTable, len) ->
            sqlColumns hSTMT nullPtr 0 nullPtr 0 pTable (fromIntegral len) nullPtr 0

        getColumnInfo stmt = do
          column_name    <- getFieldValue stmt "COLUMN_NAME"
          data_type      <- getFieldValue stmt "DATA_TYPE" :: IO Int
          -- COLUMN_SIZE and DECIMAL_DIGITS fall back to 0.
          column_size    <- getFieldValue' stmt "COLUMN_SIZE" 0 :: IO Int
          decimal_digits <- getFieldValue' stmt "DECIMAL_DIGITS" 0 :: IO Int
          nullable       <- getFieldValue stmt "NULLABLE" :: IO Int
          let sqlType = mkSqlType (fromIntegral data_type)
                                  (fromIntegral column_size)
                                  (fromIntegral decimal_digits)
          return (column_name, sqlType, toBool nullable)

    -- Advance to the next row; False once SQLFetch reports SQL_NO_DATA (100).
    fetch :: HSTMT -> IO Bool
    fetch hSTMT = do
      res <- sqlFetch hSTMT
      checkStmt hSTMT res
      return (res /= 100)

    -- Fetch column 'colNumber' (0-based) of the current row as character
    -- data and hand it to the continuation 'f'.  A NULL value is passed as
    -- 'nullPtr' with length 0.
    getColValue :: HSTMT -> CString -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue hSTMT buffer colNumber fieldDef f = do
      (res, len_or_ind) <- getData buffer (fromIntegral stmtBufferSize)
      if len_or_ind == (-1)           -- SQL_NULL_DATA (sign had been lost)
        then f fieldDef nullPtr 0
        else if res == 1              -- SQL_SUCCESS_WITH_INFO: value truncated
          then getLongData len_or_ind
          else f fieldDef buffer (fromIntegral len_or_ind)
      where
        -- Read (the next part of) the column into the given buffer and
        -- return the status code with the length/indicator value.
        getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
        getData target size = alloca $ \lenP -> do
          res <- sqlGetData hSTMT (fromIntegral colNumber + 1) 1   -- 1 = SQL_C_CHAR
                   (castPtr target) size lenP
          checkStmt hSTMT res
          len_or_ind <- peek lenP
          return (res, len_or_ind)

        -- The value did not fit into the shared buffer.  'len' is the length
        -- reported by the first read; allocate a buffer for the whole value,
        -- copy what was already read and fetch the remainder behind it.
        -- NOTE(review): a driver that reports SQL_NO_TOTAL instead of a real
        -- length is not handled here -- confirm whether that can occur.
        getLongData len =
          allocaBytes (fromIntegral newBufSize) $ \newBuf -> do
            copyBytes newBuf buffer stmtBufferSize
            -- The last byte of the first chunk is its terminating NUL, so
            -- the remainder is written over it.
            let newDataStart = newBuf `plusPtr` (stmtBufferSize - 1)
                newDataLen   = newBufSize - (fromIntegral stmtBufferSize - 1)
            _ <- getData newDataStart newDataLen
            f fieldDef newBuf (fromIntegral newBufSize - 1)
          where
            newBufSize = len + 1      -- room for the terminating NUL

    -- Release the fetch buffer and drop the statement handle (1 = SQL_DROP).
    closeStatement :: HSTMT -> CString -> IO ()
    closeStatement hSTMT buffer = do
      free buffer
      sqlFreeStmt hSTMT 1 >>= checkStmt hSTMT