from __future__ import division
import ctypes as C, os
import global_variables as GV
from numpy import *

# Make every class defined in this module a new-style class.
__metaclass__ = type

# NOTE: apparently a C int is int32 on Opteron.
# FIXME: is a C char always uint8?

# Shared library holding the C solver; its file name embeds the host name.
RS = C.CDLL (os.path.join (GV.ONSET_root, "librectify%s.so" % GV.host_name))

# Memory-layout requirements imposed on every array argument below.
flags = 'aligned, contiguous, writeable'

# Argument types of RS.solve, listed in the order in which rec_solve ()
# passes them.
RS.solve.argtypes = (
    # n, W
    [C.c_int32,
     ctypeslib.ndpointer (double, ndim = 2, flags = flags)]
    # s, W_shrunk, LU, s_shrunk, m_shrunk
    + 5 * [ctypeslib.ndpointer (double, ndim = 1, flags = flags)]
    # inds, sst_ws
    + 2 * [ctypeslib.ndpointer (intp, ndim = 1, flags = flags)]
    # ch_ws, pivot, max_iter
    + [ctypeslib.ndpointer (uint8, ndim = 1, flags = flags),
       ctypeslib.ndpointer (int32, ndim = 1, flags = flags),
       C.c_int32])

# Current workspace capacity; rec_solve () raises it when a larger problem
# comes in.
N = 1024

def alloc ():
    """(Re)build the module-level scratch arrays passed to the C solver.

    All sizes derive from the module-level ``N`` at the time of the call;
    the arrays are created uninitialised.
    """
    global W_shrunk, LU, s_shrunk, m_shrunk, inds, sst_ws, ch_ws, pivot

    length = N
    area = length * length

    # flat N*N buffers of doubles
    W_shrunk = empty (area, dtype = double)
    LU = empty (area, dtype = double)

    # length-N vectors of doubles
    s_shrunk = empty (length, dtype = double)
    m_shrunk = empty (length, dtype = double)

    # length-N integer and byte work areas
    inds = empty (length, dtype = intp)
    sst_ws = empty (length, dtype = intp)
    ch_ws = empty (length, dtype = uint8)

    # one extra slot so that 'info' fits as well
    pivot = empty (length + 1, dtype = int32)

def rec_solve (W, s, max_iter = 1000):
    """Run the C routine ``RS.solve`` on ``W`` and ``s`` and unpack its output.

    The module-level workspace is enlarged through ``alloc ()`` whenever
    ``len (s)`` exceeds the current capacity ``N``.

    Parameters:
        W: 2-d array of doubles of shape ``(n, n)`` with ``n = len (s)``.
            It must meet the layout requirements declared in
            ``RS.solve.argtypes``.
        s: 1-d array of doubles of length ``n`` (same requirements).
        max_iter: iteration budget handed to the C routine unchanged.

    Returns:
        A tuple ``(i, W_sh, m, no_iter)`` where, with ``NN`` the
        non-negative value returned by the solver:

        * ``i`` is a copy of the first ``NN`` entries of ``inds``;
        * ``W_sh`` is an ``(NN, NN)`` copy of the first ``NN * NN`` entries
          of ``W_shrunk``;
        * ``m`` is shaped like ``s``, zero everywhere except at the
          positions ``i``, which receive the first ``NN`` entries of
          ``m_shrunk``;
        * ``no_iter`` is ``max_iter`` minus the value the solver left in
          ``pivot [n]`` (presumably the number of iterations used --
          confirm against the C source).

    Raises:
        AssertionError: if ``W.shape`` is not ``(n, n)``.
        RuntimeError: if the solver returns a negative status code.
    """
    global N
    n = len (s)
    # Checked explicitly rather than with ``assert`` so the guard survives
    # ``python -O``: the C code receives raw pointers and relies on this
    # shape.  AssertionError is kept so existing callers see the same type.
    if W.shape != (n, n):
        raise AssertionError ("W must have shape (%d, %d), got %s"
                              % (n, n, W.shape))
    if n > N:
        previous = N
        N = n
        try:
            alloc ()
        except BaseException:
            # Keep N in step with the buffers that really exist; otherwise
            # a later call would pass undersized workspace to the C code.
            N = previous
            raise
    NN = RS.solve (n, W, s, W_shrunk, LU, s_shrunk, m_shrunk, inds, sst_ws,
                   ch_ws, pivot, max_iter)
    if NN < 0:
        # On failure the solver leaves extra information in pivot [n].
        if NN == -1:
            message = "solver failed (max_iter)"
        elif NN == -2:
            message = "solver failed (LU factorisation (%s))" % pivot [n]
        elif NN == -3:
            message = "solver failed (LU solver (%s))" % pivot [n]
        else:
            message = "solver failed (unknown error (%s)!)" % NN
        raise RuntimeError (message)
    no_iter = max_iter - pivot [n]
    # Copy everything handed back: the workspace is reused by the next call.
    i = inds [:NN].copy ()
    m = zeros_like (s)
    m [i] = m_shrunk [:NN]
    W_sh = W_shrunk [:NN * NN].copy ()
    W_sh.shape = NN, NN
    return i, W_sh, m, no_iter

# Create the initial workspace (sized by N) as soon as the module is imported.
alloc ()
