# Alternative ("fixed") implementations of the tri, tril and triu functions
# from NumPy's twodim_base.py.
from numpy.core.numeric import asanyarray, equal, subtract, arange, \
     zeros, greater_equal, multiply, ones, asarray, alltrue, where, \
     empty

# Terse module-level aliases for NumPy routines imported above.
# ``so(a, b)[i, j] == a[i] - b[j]`` (the outer difference of two vectors).
ge, mul, so, ar = greater_equal, multiply, subtract.outer, arange

def tri_(M, N=None, k=0, dtype=float):
    """
    An array with ones at and below the given diagonal and zeros elsewhere.

    Parameters
    ----------
    M : int
        Number of rows in the array.
    N : int, optional
        Number of columns in the array.
        By default, `N` is taken equal to `M`.
    k : int, optional
        The sub-diagonal at and below which the array is filled.
        `k` = 0 is the main diagonal, while `k` < 0 is below it,
        and `k` > 0 is above.  The default is 0.
    dtype : dtype, optional
        Data type of the returned array.  The default is float.

    Returns
    -------
    T : ndarray of shape (M, N)
        Array with its lower triangle filled with ones and zero elsewhere;
        in other words ``T[i,j] == 1`` for ``j <= i + k``, 0 otherwise.

    Examples
    --------
    >>> tri_(3, 5, 2, dtype=int)
    array([[1, 1, 1, 0, 0],
           [1, 1, 1, 1, 0],
           [1, 1, 1, 1, 1]])

    >>> tri_(3, 5, -1)
    array([[0., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0.]])

    """
    if N is None:
        N = M
    # subtract.outer(...)[i, j] == i - j, so ``i - j >= -k`` selects exactly
    # the entries with ``j <= i + k``.  The comparison yields a bool array.
    mask = greater_equal(subtract.outer(arange(M), arange(N)), -k)
    # copy=False hands back ``mask`` itself for any bool spelling of dtype
    # (bool, numpy.bool_, 'bool', ...) and converts for everything else.
    return mask.astype(dtype, copy=False)

def tril_(m, k=0):
    """
    Lower triangle of an array.

    Return a copy of an array with elements above the `k`-th diagonal zeroed.
    For an array with more than two dimensions the operation is applied to
    its last two axes.

    Parameters
    ----------
    m : array_like, shape (..., M, N)
        Input array with at least two dimensions.
    k : int, optional
        Diagonal above which to zero elements.  `k = 0` (the default) is the
        main diagonal, `k < 0` is below it and `k > 0` is above.

    Returns
    -------
    L : ndarray, shape (..., M, N)
        Lower triangle of `m`, of same shape and data-type as `m`.  Entries
        above the diagonal are set to zero even when they hold ``inf`` or
        ``nan``.  The input is not modified.

    Raises
    ------
    IndexError
        If `m` has fewer than two dimensions.

    See Also
    --------
    triu_ : same thing, only for the upper triangle

    Examples
    --------
    >>> tril_([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
    array([[ 0,  0,  0],
           [ 4,  0,  0],
           [ 7,  8,  0],
           [10, 11, 12]])

    """
    m = asanyarray(m)
    rows, cols = m.shape[-2], m.shape[-1]
    # subtract.outer(...)[i, j] == i - j, so keep[i, j] is True exactly
    # where j <= i + k (on or below the k-th diagonal).
    keep = greater_equal(subtract.outer(arange(rows), arange(cols)), -k)
    # Zero by assignment on a copy instead of multiplying by the mask:
    # ``0 * inf`` and ``0 * nan`` are nan, and assignment keeps m's dtype
    # without a recast.  The leading ``...`` applies the 2-D mask to every
    # matrix of a stacked (N-D) input.
    out = m.copy()
    out[..., ~keep] = 0
    return out

def triu_(m, k=0):
    """
    Upper triangle of an array.

    Return a copy of an array with the elements below the `k`-th diagonal
    zeroed.  For an array with more than two dimensions the operation is
    applied to its last two axes.

    Parameters
    ----------
    m : array_like, shape (..., M, N)
        Input array with at least two dimensions.
    k : int, optional
        Diagonal below which to zero elements.  `k = 0` (the default) is the
        main diagonal, `k < 0` is below it and `k > 0` is above.

    Returns
    -------
    U : ndarray, shape (..., M, N)
        Upper triangle of `m`, of same shape and data-type as `m`.  Entries
        below the diagonal are set to zero even when they hold ``inf`` or
        ``nan``.  The input is not modified.

    Raises
    ------
    IndexError
        If `m` has fewer than two dimensions.

    See Also
    --------
    tril_ : lower triangle of an array

    Examples
    --------
    >>> triu_([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
    array([[ 1,  2,  3],
           [ 4,  5,  6],
           [ 0,  8,  9],
           [ 0,  0, 12]])

    """
    m = asanyarray(m)
    rows, cols = m.shape[-2], m.shape[-1]
    # subtract.outer(...)[i, j] == i - j, so ``i - j >= 1 - k`` is True
    # exactly where j < i + k, i.e. strictly below the k-th diagonal.
    # Building this mask directly avoids ``True - bool_array``, which
    # current NumPy rejects with a TypeError.
    below = greater_equal(subtract.outer(arange(rows), arange(cols)), 1 - k)
    # Zero by assignment on a copy (not by multiplication, which turns
    # inf/nan into nan); this also keeps m's dtype.  ``...`` applies the
    # 2-D mask to every matrix of a stacked (N-D) input.
    out = m.copy()
    out[..., below] = 0
    return out

if __name__ == '__main__':
    # Self-check: compare the alternative implementations with NumPy's own
    # tri / triu / tril on a few shapes and diagonals.
    from numpy import allclose, dtype, ones, tri, triu, tril

    assert allclose(tri(7), tri_(7))
    assert allclose(tri(7, k=1), tri_(7, k=1))
    assert allclose(tri(7, k=-1), tri_(7, k=-1))
    assert allclose(tri(7, 5, k=-1), tri_(7, 5, k=-1))

    M = ones((5, 7))  # float array
    assert allclose(triu(M, k=-1), triu_(M, k=-1))
    assert allclose(triu(M, k=1), triu_(M, k=1))
    assert allclose(tril(M, k=-1), tril_(M, k=-1))
    assert allclose(tril(M, k=1), tril_(M, k=1))
    assert allclose(triu(M), triu_(M))

    # A bool input must stay bool in the alternative versions.  NumPy's own
    # triu is deliberately not checked for dtype here: whether it preserves
    # bool depends on the NumPy release (current releases do).
    M = M.astype(bool)
    assert allclose(triu(M), triu_(M))
    assert triu_(M).dtype == dtype(bool)
    assert tril_(M).dtype == dtype(bool)
