{-
  simplified conjugate gradient, as outlined by Jan-Willem Maessen:

   for j <- seq(1#cgit) do
     q = A p
     alpha = rho / (p DOT q)
     z += alpha p
     rho0 = rho
     r -= alpha q
     rho := r DOT r
     beta = rho / rho0
     p := r + beta p
   end

  Here p, q, r, and z are vectors, A is the derivative of our function
  (in this case a sparse symmetric positive-definite matrix, but we can
  really think of it as a higher-order function of type Vector->Vector),
  and the Greek letters are scalars.  The "answer" is z.  In practice
  we would not run a fixed number of iterations, but would instead do a
  convergence test.  All the hard work is in the line "q = A p", but
  the storage consumption is mostly in the surrounding code.
-}

{- OK, let's translate that into Haskell, for concreteness.
 - Three variants: naive, list-based code
 -                 strict-list-based code
 -                 array-based, in-place-update code
 -}

import Control.Monad (unless)
import Control.Monad.ST
import Data.Array.Base
import Data.Array.IArray
import Data.Array.ST
import Debug.Trace (trace)
import System.Environment (getArgs)
import System.Exit (die)
import Text.Read (readMaybe)

-- Element type shared by all three implementations.
type Elem = Double

-- Problem size.  Overloaded so the same constant serves both as an Int
-- loop bound in the array code and as an enumeration limit in [1..n].
n :: Num a => a
n = 40

-- The system matrix: the n-by-n identity, as a list of rows.
a :: Matrix
a = [[if row == col then 1 else 0 | col <- [1 .. n]] | row <- [1 .. n]]

-- Starting vectors: p and r both count 1, 2 .. n; z starts as all ones.
p, r, z :: Vector
p = [1 .. n]
r = [1 .. n]
z = replicate n 1

-- Starting value of rho.
-- NOTE(review): textbook CG would start from rho = r.r; this benchmark
-- simply uses 1 -- confirm that is intended.
rho :: Elem
rho = 1

-- Stop criterion for the main loops: stop once the remaining iteration
-- count is used up.  The current solution vector is passed in so a real
-- convergence test could be plugged in here, but it is ignored for now.
-- Testing (<= 0) rather than (== 0) ensures a negative count stops
-- immediately instead of counting down forever.
stop :: (Num c, Ord c) => v -> c -> Bool
stop _ remaining = remaining <= 0

------------------------ naive list-based code

-- Naive representation: a vector is an ordinary (lazy) list of elements
-- and a matrix is a list of row vectors.
type Vector = [Elem]
type Matrix = [[Elem]]

-- Scalar times vector.
s .* v = [s * x | x <- v]
-- Componentwise sum and difference (stopping at the shorter vector, as
-- zip does).
v1 `add` v2 = [x + y | (x, y) <- zip v1 v2]
v1 `sub` v2 = [x - y | (x, y) <- zip v1 v2]
-- Inner product.
v1 `dot` v2 = sum [x * y | (x, y) <- zip v1 v2]
-- Matrix-vector product: one inner product per row of m.
m `mat` v = [row `dot` v | row <- m]

-- Naive list-based CG loop: one iteration of the recurrence from the
-- header per recursive call.  Arguments are the matrix, search direction
-- p, residual r, solution z, the rho from the previous step (the
-- caller's starting value on the first call) and the remaining count.
-- Returns z once 'stop' says so.
loop mA dir res sol rhoOld k
  | stop sol k = sol
  | otherwise  = loop mA dir' res' sol' rhoNew (k - 1)
  where
    aDir   = mA `mat` dir
    step   = rhoOld / (dir `dot` aDir)
    sol'   = sol `add` (step .* dir)
    res'   = res `sub` (step .* aDir)
    rhoNew = res' `dot` res'
    ratio  = rhoNew / rhoOld
    dir'   = res' `add` (ratio .* dir)

-- Run the naive list-based solver on the shared inputs for the given number of iterations.
test iterations = loop a p r z rho iterations

------------------------ strict list-based code

-- A list that is strict in both its elements and its spine: building a
-- cell forces the head and the tail, so no thunks remain inside a
-- fully-constructed list.
data StrictList a
  = Nil
  | !a :< !(StrictList a)
  deriving Show

-- Strict-list vector, and a strict-list matrix.  The matrix is a strict
-- list of strict row vectors, matching what testS actually builds with
-- `fromList (map fromList a)`; a `StrictList Vector` would keep lazy-list
-- rows and only make the outer spine strict.
type VectorS = StrictList Elem
type MatrixS = StrictList VectorS

-- Right fold over a strict list (the StrictList counterpart of foldr).
foldS f e = go
  where
    go Nil       = e
    go (x :< xs) = f x (go xs)

-- map and sum for strict lists, both expressed as right folds.
mapS f = foldS (\x rest -> f x :< rest) Nil
sumS   = foldS (\x total -> x + total) 0

-- Fused sum-of-products (sumS (zipWithS (*) v1 v2) without building the
-- intermediate strict list).  Like the naive list code's `dot`, which is
-- built on zipWith, it ignores the surplus tail of the longer vector
-- instead of failing with a pattern-match error on a length mismatch.
summulS (x :< xs) (y :< ys) = x * y + summulS xs ys
summulS _         _         = 0

-- Strict-list analogue of zipWith.  Like zipWith (and hence the naive
-- list code), it stops at the end of the shorter list instead of failing
-- with a pattern-match error when the lengths differ.
zipWithS op (x :< xs) (y :< ys) = op x y :< zipWithS op xs ys
zipWithS _  _         _         = Nil

-- Conversions between ordinary lists and strict lists.
fromList []       = Nil
fromList (x : xs) = x :< fromList xs

toList xs = foldS (\x acc -> x : acc) [] xs

-- Vector algebra on strict lists.
smulS s v  = mapS (\x -> s * x) v
addS v1 v2 = zipWithS (+) v1 v2
subS v1 v2 = zipWithS (-) v1 v2
dotS v1 v2 = summulS v1 v2 -- fused form of: sumS (zipWithS (*) v1 v2)
matS m v   = mapS (\row -> row `dotS` v) m

-- Fused "v1 + x*v2" and "v1 - x*v2": a single zipWithS pass, so no
-- intermediate strict list is built for the scaled vector.
addmulS v1 (x, v2) = zipWithS (\u w -> u + x * w) v1 v2
submulS v1 (x, v2) = zipWithS (\u w -> u - x * w) v1 v2

-- Strict-list CG loop.  Same recurrence as the naive version, but the
-- scaled-vector updates go through the fused addmulS/submulS helpers.
-- NOTE(review): nothing here forces dir', res', sol' or rhoNew between
-- iterations (stop only inspects the count), so the per-step work may be
-- deferred until the final vector is demanded; consider bang patterns on
-- the arguments if that is not intended.
loopS mA dir res sol rhoOld k
  | stop sol k = sol
  | otherwise  = loopS mA dir' res' sol' rhoNew (k - 1)
  where
    aDir   = mA `matS` dir
    step   = rhoOld / (dir `dotS` aDir)
    sol'   = sol `addmulS` (step, dir)     -- sol + step*dir
    res'   = res `submulS` (step, aDir)    -- res - step*aDir
    rhoNew = res' `dotS` res'
    ratio  = rhoNew / rhoOld
    dir'   = res' `addmulS` (ratio, dir)   -- res' + ratio*dir

-- Run the strict-list solver: convert the shared inputs (the matrix row
-- by row) to strict lists and iterate the given number of times.
testS iterations = loopS strictA (fromList p) (fromList r) (fromList z) rho iterations
  where strictA = fromList [fromList row | row <- a]

------------------------ array-based, update-in-place code

-- Mutable unboxed arrays in the ST monad.  testA allocates vectors over
-- indices 0..n with slot 0 unused (the loops below run from 1 to n); the
-- matrix is indexed by (row, column) pairs.
type VectorA s = STUArray s Int Elem
type MatrixA s = STUArray s (Int, Int) Elem

-- Apply f to the element at linear (0-based) offset ix of arr, in place.
modArray !arr !ix f = do
  old <- unsafeRead arr ix
  unsafeWrite arr ix (f old)

-- In-place "dst += k*src" and "dst -= k*src" over indices 1..n.
(+*=), (-*=) :: VectorA s -> (Elem, VectorA s) -> ST s ()
dst +*= (k, src) = dst `seq` k `seq` src `seq` go 1
  where
    go !ix
      | ix > n    = return ()
      | otherwise = do
          d <- unsafeRead dst ix
          e <- unsafeRead src ix
          unsafeWrite dst ix $! d + k * e
          go (ix + 1)

dst -*= (k, src) = dst `seq` k `seq` src `seq` go 1
  where
    go !ix
      | ix > n    = return ()
      | otherwise = do
          d <- unsafeRead dst ix
          e <- unsafeRead src ix
          unsafeWrite dst ix $! d - k * e
          go (ix + 1)

-- In-place "v1 := v2 + x*v1" over indices 1..n (used for p := r + beta*p).
(*+=) :: (Elem, VectorA s) -> VectorA s -> ST s ()
(x, v1) *+= v2 = x `seq` v1 `seq` v2 `seq` go 1
  where
    go !ix
      | ix > n    = return ()
      | otherwise = do
          e2 <- unsafeRead v2 ix
          modArray v1 ix (\old -> e2 + x * old)
          go (ix + 1)

-- Inner product of two array vectors over indices 1..n (slot 0 skipped),
-- accumulated left to right with a strict accumulator.
dotA :: VectorA s -> VectorA s -> ST s Elem
dotA u w = u `seq` w `seq` go 1 0
  where
    go !ix !acc
      | ix > n    = return acc
      | otherwise = do
          x <- unsafeRead u ix
          y <- unsafeRead w ix
          go (ix + 1) $! acc + x * y

-- Matrix-vector product written into tmp (rows and columns 1..n), which is
-- also returned.  The matrix is read by linear offset i*(n+1)+j, i.e. its
-- (0,0)..(n,n) bounds laid out row-major with n+1 columns.
matA :: MatrixA s -> VectorA s -> VectorA s -> ST s (VectorA s)
matA m v tmp = m `seq` v `seq` tmp `seq` rows 1
  where
    -- one output element per row i
    rows !i
      | i > n     = return tmp
      | otherwise = do
          total <- cols i 1 0
          unsafeWrite tmp i total
          rows (i + 1)
    -- inner product of row i with v, accumulated over columns j
    cols !i !j !acc
      | j > n     = return acc
      | otherwise = do
          x <- unsafeRead m (i * (n + 1) + j)
          y <- unsafeRead v j
          cols i (j + 1) $! acc + x * y

-- In-place CG loop: updates dir, res and sol destructively and uses
-- scratch to hold A*dir.  Returns the sol array once 'stop' says so.
loopA mA dir res sol scratch rhoOld k
  | stop sol k = return sol
  | otherwise  = do
      _ <- matA mA dir scratch            -- scratch := A dir
      pq <- dir `dotA` scratch
      let step = rhoOld / pq
      sol +*= (step, dir)                 -- sol += step*dir
      res -*= (step, scratch)             -- res -= step*(A dir)
      rhoNew <- res `dotA` res
      (rhoNew / rhoOld, dir) *+= res      -- dir := res + beta*dir
      loopA mA dir res sol scratch rhoNew (k - 1)

-- Run the array solver: build the (n+1)x(n+1) identity matrix and the
-- vectors with an unused slot 0 prepended, iterate the given number of
-- times, and freeze the resulting z array.
testA iterations = runSTUArray $ do
  let identity  = [if i == j then 1 else 0 | j <- [0 .. n], i <- [0 .. n]]
      padded xs = newListArray (0, n) (0 : xs)
  aA <- newListArray ((0, 0), (n, n)) identity
  pA <- padded p
  rA <- padded r
  zA <- padded z
  qA <- newArray (0, n) 0
  loopA aA pA rA zA qA rho iterations

-----------------------
-- | Command-line driver: `prog VERSION COUNT` runs COUNT iterations of the
-- chosen implementation (list, listS or array) and prints the resulting
-- z vector; `check` runs all three and reports whether, and by how much,
-- their results differ.  Missing arguments, an unknown VERSION, or a
-- non-numeric or negative COUNT print a usage message and exit with a
-- failure status instead of crashing with a pattern-match or read error
-- (or, for a negative count, never terminating).
main :: IO ()
main = do
  args <- getArgs
  case args of
    (version : count : _)
      | Just c <- readMaybe count, c >= 0 -> run version c
    _ -> die usage
  where
    usage :: String
    usage = "usage: <program> (list|listS|array|check) COUNT"

    run :: String -> Integer -> IO ()
    run "list"  c = print $ test  c -- 100000: 2m3s
    run "listS" c = print $ testS c -- 100000: 12s
    run "array" c = print $ testA c -- 100000: 33s
    run "check" c = do
      let l  = test c
          ts = toList $ testS c
          ea = tail $ elems $ testA c   -- drop the unused slot 0
          diff xs ys = maximum $ map abs $ zipWith (-) xs ys
      putStrLn $ "list==listS?  "++show (l==ts)++" "++show (diff l ts)
      putStrLn $ "list==array?  "++show (l==ea)++" "++show (diff l ea)
      putStrLn $ "listS==array? "++show (ts==ea)++" "++show (diff ts ea)
    run other _ = die ("unknown version: " ++ other ++ "\n" ++ usage)


