{-
  simplified conjugate gradient, as outlined by Jan-Willem Maessen:

   for j <- seq(1#cgit) do
     q = A p
     alpha = rho / (p DOT q)
     z += alpha p
     rho0 = rho
     r -= alpha q
     rho := r DOT r
     beta = rho / rho0
     p := r + beta p
   end

  Here p, q, r, and z are vectors, A is the derivative of our function
  (in this case a sparse symmetric positive-definite matrix, but we can
  really think of it as a higher-order function of type Vector->Vector),
  and the Greek letters are scalars.  The "answer" is z.  In practice
  we would not run a fixed number of iterations, but would instead do a
  convergence test.  All the hard work is in the line "q = A p", but
  the storage consumption is mostly in the surrounding code.
-}

{- OK, let's translate that into Haskell, for concreteness.
 - Three variants: naive list-based code,
 -                 strict-list-based code,
 -                 array-based, in-place-update code.
 -}

import Data.Array.IArray
import Control.Monad(unless)
import Control.Monad.ST
import Data.Array.ST
import Debug.Trace(trace)
import System.Environment(getArgs)

-- | Scalar type shared by all three variants.
type Elem = Double

-- | Problem size: every vector has n entries and the matrix is n x n.
n :: Num a => a
n = 40

-- | The test matrix: the n x n identity (1 on the diagonal, 0 elsewhere).
a :: Matrix
a = [ [ if row == col then 1 else 0 | col <- [1 .. n] ] | row <- [1 .. n] ]

-- | Initial search direction, residual and solution estimate.
p, r, z :: Vector
p = [1 .. n]
r = [1 .. n]
z = replicate n 1

-- | Initial value of the scalar rho (which the loops later set to r DOT r).
-- NOTE(review): 1 is not r DOT r for the initial r above; presumably a
-- deliberate simplification for benchmarking -- confirm.
rho :: Elem
rho = 1

-- | Stop criterion for the main loops: halt once the remaining iteration
-- count has run out.  The solution estimate @z@ is ignored (presumably it is
-- passed so that a convergence test could be substituted later).  The test
-- is @<= 0@ rather than @== 0@ so that a negative count stops immediately
-- instead of counting down forever.
stop z n = n <= 0

------------------------ naive list-based code

-- | Dense vector as an ordinary (lazy) Haskell list.
type Vector = [Elem]
-- | Dense matrix as a list of rows (see 'mat', which dots each row with v).
type Matrix = [[Elem]]

-- | Scalar times vector.
(.*) s v = [ s * e | e <- v ]

-- | Element-wise sum and difference; like 'zipWith', the result is as long
-- as the shorter input.
add v1 v2 = [ x + y | (x, y) <- zip v1 v2 ]
sub v1 v2 = [ x - y | (x, y) <- zip v1 v2 ]

-- | Inner product.
dot v1 v2 = sum [ x * y | (x, y) <- zip v1 v2 ]

-- | Matrix-vector product: the inner product of each row with @v@.
mat m v = [ dot row v | row <- m ]

-- | Conjugate-gradient iteration on plain lists, following the pseudo-code
-- in the file header.  Arguments: matrix, search direction @p@, residual
-- @r@, solution estimate @z@, scalar @rho@, remaining iteration count.
-- Returns the solution estimate once 'stop' fires.
loop a p r z rho n
  | stop z n  = z
  | otherwise = loop a pNext rNext zNext rhoNext (n - 1)
  where
    q       = mat a p
    alpha   = rho / dot p q
    zNext   = add z (alpha .* p)
    rNext   = sub r (alpha .* q)
    rhoNext = dot rNext rNext
    beta    = rhoNext / rho
    pNext   = add rNext (beta .* p)

-- | Run the list-based variant for the given number of iterations on the
-- module's test data.
test iterations = loop a p r z rho iterations

------------------------ strict list-based code

-- | A list that is strict in both fields: evaluating a cell forces its
-- element (to WHNF) and its tail, so the whole spine is built eagerly.
data StrictList a = Nil | !a :< !(StrictList a) deriving Show

-- | Strict-list vector.
type VectorS = StrictList Elem
-- | Strict-list matrix: a strict list of strict-list rows.  The rows are
-- 'VectorS' (not the lazy 'Vector'), matching the structure that 'testS'
-- builds with @fromList (map fromList a)@ and that 'matS' consumes.
type MatrixS = StrictList VectorS

-- | Right fold over a strict list.
foldS f seed xs = case xs of
  Nil     -> seed
  y :< ys -> f y (foldS f seed ys)

-- | Apply a function to every element of a strict list.
mapS f = foldS (\x acc -> f x :< acc) Nil

-- | Sum of a strict list.
sumS = foldS (\x acc -> x + acc) 0

-- | Inner product of two strict lists in a single pass, avoiding the
-- intermediate strict list that @sumS (zipWithS (*) v1 v2)@ would build.
-- Like 'zipWith', it stops at the end of the shorter list, so mismatched
-- lengths truncate instead of raising a pattern-match failure.
summulS (a:<as) (b:<bs) = a*b + summulS as bs
summulS _       _       = 0

-- | Combine two strict lists element-wise with @op@.  As with Prelude
-- 'zipWith', the result is as long as the shorter input rather than
-- failing when the lengths differ.
zipWithS op (a:<as) (b:<bs) = (a `op` b) :< zipWithS op as bs
zipWithS _  _       _       = Nil

-- | Convert an ordinary list into a strict list; evaluating the result
-- forces every element and the full spine.
fromList []       = Nil
fromList (x : xs) = x :< fromList xs

-- | Convert a strict list back into an ordinary list.
toList s = foldS (\x acc -> x : acc) [] s

-- | Scalar times strict vector.
smulS s v = mapS (s *) v

-- | Element-wise sum and difference of strict vectors.
addS v1 v2 = zipWithS (+) v1 v2
subS v1 v2 = zipWithS (-) v1 v2

-- | Inner product via the fused 'summulS' (instead of
-- @sumS $ zipWithS (*) v1 v2@).
dotS v1 v2 = summulS v1 v2

-- | Matrix-vector product: the inner product of each row with @v@.
matS m v = mapS (\row -> row `dotS` v) m

-- | Fused @v1 + x*v2@ and @v1 - x*v2@, avoiding the intermediate strict
-- list that a separate 'smulS' would build.
addmulS v1 (x, v2) = zipWithS (\e1 e2 -> e1 + x * e2) v1 v2
submulS v1 (x, v2) = zipWithS (\e1 e2 -> e1 - x * e2) v1 v2

-- | Conjugate-gradient iteration over strict lists.  Same algorithm as
-- 'loop', but using the fused helpers so fewer intermediate lists are built.
loopS a p r z rho n
  | stop z n  = z
  | otherwise = loopS a pNew rNew zNew rhoNew (n - 1)
  where
    q      = matS a p
    alpha  = rho / dotS p q
    zNew   = addmulS z (alpha, p)    -- z + alpha*p
    rNew   = submulS r (alpha, q)    -- r - alpha*q
    rhoNew = dotS rNew rNew
    beta   = rhoNew / rho
    pNew   = addmulS rNew (beta, p)  -- r' + beta*p

-- | Run the strict-list variant for @c@ iterations on the module's test data.
testS c = loopS matrix (fromList p) (fromList r) (fromList z) rho c
  where
    matrix = fromList (map fromList a)

------------------------ array-based, update-in-place code

-- | Mutable unboxed vector in the ST monad (the loops below walk indices
-- 1..n, and 'testA' allocates with those bounds).
type VectorA st = STUArray st Int Elem
-- | Mutable unboxed matrix indexed by (row, column) pairs.
type MatrixA st = STUArray st (Int, Int) Elem

-- | Replace the element at index @i@ of a mutable array with @f@ applied to it.
modArray a i f = do
  old <- readArray a i
  writeArray a i (f old)

-- | Element-wise in-place update for indices i..n (n is the module-level
-- problem size): @v1[k] := v1[k] `op` (x * v2[k])@.  Each new value is
-- forced with @$!@ before it is written.
l v1 x v2 i op =
  if i > n
    then return ()
    else do
      e1 <- readArray v1 i
      e2 <- readArray v2 i
      writeArray v1 i $! (e1 `op` (x * e2))
      l v1 x v2 (i + 1) op

-- | In-place scaled update over indices 1..n:
-- @v1 +*= (x, v2)@ sets @v1 := v1 + x*v2@ and
-- @v1 -*= (x, v2)@ sets @v1 := v1 - x*v2@.
(+*=), (-*=) :: VectorA s -> (Elem, VectorA s) -> ST s ()
(+*=) v1 (x, v2) = l v1 x v2 1 (+)
(-*=) v1 (x, v2) = l v1 x v2 1 (-)

-- | @(x, v1) *+= v2@ sets @v1 := v2 + x*v1@ in place over indices 1..n.
(*+=) :: (Elem, VectorA s) -> VectorA s -> ST s ()
(x, v1) *+= v2 = go 1
  where
    go k
      | k > n     = return ()
      | otherwise = do
          e2 <- readArray v2 k
          modArray v1 k (\e1 -> e2 + x * e1)
          go (k + 1)

-- | Inner product of two mutable vectors over indices 1..n, accumulated
-- strictly.
dotA :: VectorA s -> VectorA s -> ST s Elem
dotA v1 v2 = go 1 0
  where
    go k acc
      | k > n     = return acc
      | otherwise = do
          e1 <- readArray v1 k
          e2 <- readArray v2 k
          go (k + 1) $! (acc + e1 * e2)

-- | Dense matrix-vector product over indices 1..n.  The inner product of
-- row i with @v@ is written to @tmp[i]@, and @tmp@ itself is returned.
matA :: MatrixA s -> VectorA s -> VectorA s -> ST s (VectorA s)
matA m v tmp = rows 1
  where
    -- outer loop over rows
    rows i
      | i > n     = return tmp
      | otherwise = cols i 1 0
    -- inner loop over columns of row i, with a strict running sum
    cols i j acc
      | j > n     = writeArray tmp i acc >> rows (i + 1)
      | otherwise = do
          mij <- readArray m (i, j)
          vj  <- readArray v j
          cols i (j + 1) $! (acc + mij * vj)

-- | Conjugate-gradient iteration on mutable arrays.  @p@, @r@ and @z@ are
-- updated in place, and @q@ is scratch space that receives @A p@ each step.
-- Returns the (mutated) array @z@ once 'stop' fires.
loopA a p r z q rho n
  | stop z n  = return z
  | otherwise = do
      _ <- matA a p q            -- q := A p
      pq <- dotA p q
      let alpha = rho / pq
      z +*= (alpha, p)           -- z := z + alpha*p
      r -*= (alpha, q)           -- r := r - alpha*q
      rhoNew <- dotA r r
      let beta = rhoNew / rho
      (beta, p) *+= r            -- p := r + beta*p
      loopA a p r z q rhoNew (n - 1)

-- | Run the in-place array variant for @c@ iterations: copy the test data
-- into fresh mutable arrays, iterate, and freeze the solution estimate.
testA c = runSTUArray (do
  scratch <- newArray (1, n) 0
  soln    <- newListArray (1, n) z
  resid   <- newListArray (1, n) r
  dir     <- newListArray (1, n) p
  matrix  <- newListArray ((1, 1), (n, n)) (concat a)
  loopA matrix dir resid soln scratch rho c)

-----------------------
-- | Command line: @VERSION COUNT@, where VERSION is list, listS, array or
-- check.  The first three run one variant for COUNT iterations and print
-- the result.  "check" runs all three and, for each pair, prints whether
-- the results are equal and their largest element-wise difference.
-- Missing or malformed arguments raise a usage error instead of a
-- pattern-match, 'read' or non-exhaustive-case failure.
main :: IO ()
main = do
  args <- getArgs
  case args of
    (version:countStr:_) | Just count <- parseCount countStr -> run version count
    _ -> usage
  where
    usage = ioError (userError "usage: (list|listS|array|check) COUNT  (COUNT a non-negative integer)")

    -- Accepts the same syntax as 'read' (surrounding whitespace allowed)
    -- but is total.  Negative counts are rejected.
    parseCount :: String -> Maybe Integer
    parseCount s = case [c | (c, rest) <- reads s, ("", "") <- lex rest] of
      [c] | c >= 0 -> Just c
      _            -> Nothing

    run :: String -> Integer -> IO ()
    run "list"  c = print $ test  c -- 100000: 2m3s
    run "listS" c = print $ testS c -- 100000: 12s
    run "array" c = print $ testA c -- 100000: 33s
    run "check" c = do
      let ls  = test c
          ts  = toList $ testS c
          ea  = elems $ testA c
          -- largest absolute element-wise difference (vectors have n > 0 entries)
          diff xs ys = maximum $ map abs $ zipWith (-) xs ys
      putStrLn $ "list==listS?  "++show (ls==ts)++" "++show (diff ls ts)
      putStrLn $ "list==array?  "++show (ls==ea)++" "++show (diff ls ea)
      putStrLn $ "listS==array? "++show (ts==ea)++" "++show (diff ts ea)
    run _ _ = usage

