{-# OPTIONS_GHC -O2 -funbox-strict-fields -optc-O3 -fexcess-precision -optc-ffast-math #-}
{-
Haskell version of Jon Harrop's ray tracer.
Pretty much a straight conversion of the code.
Phil Armstrong 2007
-}

module Main (main) where

import Control.Monad(liftM2)
import Text.Printf(printf)
import Data.List(foldl') 
import Data.Word(Word8)
import qualified Data.ByteString.Lazy as B

-- | A three-component vector of strict 'Double's.
data Vector = V !Double !Double !Double

-- | Multiply every component of a vector by a scalar.
(*|) :: Double -> Vector -> Vector
k *| V x y z = V (k * x) (k * y) (k * z)

-- | Component-wise sum of two vectors.
(+|) :: Vector -> Vector -> Vector
V ax ay az +| V bx by bz = V (ax + bx) (ay + by) (az + bz)

-- | Component-wise difference of two vectors (left minus right).
(-|) :: Vector -> Vector -> Vector
V ax ay az -| V bx by bz = V (ax - bx) (ay - by) (az - bz)

-- | Dot product of two vectors.
dot :: Vector -> Vector -> Double
dot (V ax ay az) (V bx by bz) = ax * bx + ay * by + az * bz

-- | Scale a vector by the reciprocal of its length.
unitise :: Vector -> Vector
unitise v = invLength *| v
  where
    invLength = 1.0 / sqrt (dot v v)
-- | Sentinel distance meaning "no hit": 2^1024 is beyond the finite range of
-- 'Double', so this evaluates to positive infinity.
infinity :: Double
infinity = encodeFloat 1 1024

-- | Small offset, the square root of 2^-52; 'ray_trace' uses it to push the
-- shadow-ray origin off the surface along the normal.
delta :: Double
delta = sqrt (encodeFloat 2 (-53))

-- | A ray, given as an origin point followed by a direction vector.
data Ray = Ray !Vector !Vector
-- | A sphere, given as its centre followed by its radius.
data Sphere = Sphere !Vector !Double
-- | A scene tree: 'S' is a single sphere; 'G' pairs a sphere with a list of
-- sub-scenes, which 'intersect' only descends into when the ray reaches that
-- sphere closer than the best hit found so far.
data Scene = S !Sphere | G !(Sphere,[Scene])

-- | Find the closest intersection of a ray with a scene.
--
-- Returns the distance along the ray together with the unit surface normal at
-- the hit point.  When nothing is hit the distance is 'infinity' and the
-- normal is the zero vector.
intersect :: Ray -> Scene -> (Double, Vector)
intersect (Ray origin dir) scene = go (infinity, V 0.0 0.0 0.0) scene
  where
    -- Distance along the ray to a sphere, or 'infinity' when the ray misses
    -- it or the sphere lies entirely behind the origin.
    hitDistance (Sphere centre radius)
        | disc < 0.0 || far < 0.0 = infinity
        | near > 0.0              = near
        | otherwise               = far   -- origin is inside the sphere
        where
          toCentre = centre -| origin
          b        = dot toCentre dir
          disc     = b * b - dot toCentre toCentre + radius * radius
          root     = sqrt disc
          near     = b - root
          far      = b + root

    -- Refine the best hit so far against one node of the scene tree.
    go best@(nearest, _) (S sphere@(Sphere centre _))
        | dist >= nearest = best
        | otherwise       = (dist, unitise (origin +| (dist *| dir) -| centre))
        where
          dist = hitDistance sphere
    go best@(nearest, _) (G (bound, children))
        -- Skip the whole group unless its sphere is nearer than the best hit.
        | hitDistance bound >= nearest = best
        | otherwise                    = foldr (\child acc -> go acc child) best children

-- | Brightness contributed by a single ray, in the range 0 to 1.
--
-- The result is 0 when the ray hits nothing, when the surface normal does not
-- oppose the light direction, or when a second ray cast from just above the
-- surface back along the light direction hits something (shadow).  Otherwise
-- it is the negated dot product of the normal and the light direction.
ray_trace :: Vector -> Ray -> Scene -> Double
ray_trace light ray@(Ray origin dir) scene
    | dist == infinity      = 0.0
    | facing >= 0.0         = 0.0
    | shadowDist < infinity = 0.0
    | otherwise             = negate facing
    where
      (dist, normal)  = intersect ray scene
      facing          = dot normal light
      -- Lift the shadow-ray origin off the surface by 'delta' along the normal.
      surfacePoint    = origin +| (dist *| dir) +| (delta *| normal)
      towardsLight    = (-1.0) *| light
      (shadowDist, _) = intersect (Ray surfacePoint towardsLight) scene

-- | Build a recursive sphere scene.
--
-- @create level c r@ returns a single sphere of radius @r@ centred at @c@ when
-- @level@ is 1 or less.  For larger levels it returns a group whose enclosing
-- sphere has radius @3 * r@ and whose members are that same sphere plus four
-- sub-scenes of radius @r / 2@, built with @level - 1@ and offset from @c@ by
-- @(+-r', r', +-r')@ where @r' = 3 * r / sqrt 12@.
--
-- Any level below 1 is treated as 1; the previous base case matched only the
-- literal 1, so a zero or negative level recursed without end.
create :: (Ord a, Num a) => a -> Vector -> Double -> Scene
create level c r
    | level <= 1 = S (Sphere c r)
    | otherwise  = G (Sphere c (3.0 * r), foldl' addChild [S (Sphere c r)] offsets)
    where
      r'      = 3.0 * r / sqrt 12.0
      offsets = [(-r', -r'), (r', -r'), (-r', r'), (r', r')]
      -- Children are prepended, so the centre sphere stays last in the list.
      addChild acc (x, z) = create (level - 1) (c +| V x r' z) (r / 2.0) : acc

-- | Image width and height in pixels.
n :: Int
n = 512

-- | Number of sub-samples per pixel along each axis.
ss :: Int
ss = 4

-- | The scene to render: a six-level sphere hierarchy.
scene :: Scene
scene = create 6 (V 0.0 (-1.0) 0.0) 1.0

-- | Unit vector giving the direction in which the light travels.
light :: Vector
light = unitise (V (-1.0) (-3.0) 2.0)

-- | Brightness of every sub-sample of the pixel at row @y@, column @x@.
--
-- One ray is cast per @(dx, dy)@ pair, giving @ss * ss@ values; each value is
-- forced as its list cell is produced.
pixel_vals :: Int -> Int -> [Double]
pixel_vals y x =
    subSteps >>= \dx ->
    subSteps >>= \dy ->
    return $! sample dx dy
  where
    subSteps = [0 .. ss - 1]
    eye      = V 0.0 0.0 (-4.0)
    sample dx dy =
        ray_trace light (Ray eye (unitise (V (f x dx) (f y dy) (fromIntegral n)))) scene

-- | Map a pixel index and a sub-sample index to a coordinate centred on the
-- middle of the image: @(a - n / 2) + da / ss@.
f :: Int -> Int -> Double
f a da = centred + subOffset
  where
    centred   = fromIntegral a - fromIntegral n / 2.0
    subOffset = fromIntegral da / fromIntegral ss

-- | All pixel bytes of the image, rows running from @n - 1@ down to 0 and
-- columns from 0 up to @n - 1@ within each row.
--
-- Each byte is the sum of the pixel's sub-samples scaled to 0..255 and
-- rounded to nearest (add 0.5, then floor).
picture :: [Word8]
picture =
    [ toByte (sum (pixel_vals row col))
    | row <- [n - 1, n - 2 .. 0]
    , col <- [0 .. n - 1]
    ]
  where
    sampleCount  = fromIntegral (ss * ss)
    toByte total = fromIntegral (floor (0.5 + (255.0 * total) / sampleCount))

-- | Write the rendered image to stdout as a binary PGM (P5) stream.
--
-- The header is emitted as raw bytes in the same ByteString write as the
-- pixels.  Printing it through the text layer would let newline translation
-- on a text-mode handle (CRLF on Windows) add a byte after the maxval line,
-- which a PGM reader would take as the first pixel.
main :: IO ()
main = B.putStr (B.append header (B.pack picture))
  where
    headerText :: String
    headerText = printf "P5\n%d %d\n255\n" n n
    -- The header is pure ASCII, so every character maps to exactly one byte.
    header = B.pack (map (fromIntegral . fromEnum) headerText)
