{-# LANGUAGE ForeignFunctionInterface #-}

module Dist (dist_slow, dist_slow', dist_fast, dist_fast_inlined, c_dist) where

import Data.Array.Vector (UArr, toU, sumU, zipWithU)
import Data.Foldable (foldl')

import Foreign
import Foreign.C.Types

-- | Euclidean distance between two n-dimensional points.
--
-- Reference ("slow") implementation over plain lists of 'Double', summed with
-- Prelude 'sum' (whose accumulation strictness depends on the @base@ version;
-- compare 'dist_slow'', which forces a strict left fold). Coordinates are
-- paired with 'zipWith', so when the lists differ in length the extra
-- coordinates of the longer list are ignored.
dist_slow :: [Double] -> [Double] -> Double
dist_slow p1 p2 = sqrt (sum (zipWith squaredDiff p1 p2))
  where
    squaredDiff a b = let delta = a - b in delta * delta

-- | Euclidean distance between two n-dimensional points.
--
-- Same result as 'dist_slow', but the squared differences are accumulated
-- with the strict left fold 'foldl'', so the running total is forced at each
-- step instead of building a chain of '(+)' thunks. Coordinates are paired
-- with 'zip', so extra coordinates of the longer list are ignored.
dist_slow' :: [Double] -> [Double] -> Double
dist_slow' p1 p2 = sqrt total
  where
    total   = foldl' (+) 0.0 squares
    squares = [ square (x - y) | (x, y) <- zip p1 p2 ]
    square d = d * d

-- | Euclidean distance between two n-dimensional points.
--
-- Fast implementation over unboxed @uvector@ arrays using the streaming
-- combinators 'zipWithU' and 'sumU' (presumably so the intermediate array of
-- squared differences can be fused away under optimisation — confirm with
-- the generated Core). The sum is forced with '$!' before taking the root.
dist_fast :: UArr Double -> UArr Double -> Double
dist_fast p1 p2 = sqrt $! sumU (zipWithU squaredDiff p1 p2)
  where
    squaredDiff a b = let delta = a - b in delta * delta

-- | Euclidean distance between two n-dimensional points.
--
-- Same computation as 'dist_fast', but marked @INLINE@ so GHC may inline the
-- definition at its call sites, giving the @uvector@ stream combinators a
-- chance to fuse with the caller's code.
dist_fast_inlined :: UArr Double -> UArr Double -> Double
{-# INLINE dist_fast_inlined #-}
dist_fast_inlined p1 p2 =
  let squares = zipWithU (\a b -> let delta = a - b in delta * delta) p1 p2
      total   = sumU squares
  in  sqrt $! total

-- | Binding to the C function @dist@ via the @ccall@ convention, declared
-- @unsafe@ (the C code must not call back into Haskell).
--
-- NOTE(review): presumably the 'CInt' is the number of coordinates and the two
-- pointers reference arrays of at least that many doubles, mirroring the
-- Haskell @dist_*@ functions above — confirm against the C source.
--
-- NOTE(review): the result type is a pure 'CDouble', not @IO CDouble@, even
-- though the C side presumably reads through the pointers. GHC may therefore
-- share, float or reorder calls; only use this on buffers that stay unchanged
-- for as long as the result might still be demanded.
foreign import ccall unsafe "dist" c_dist :: CInt -> Ptr CDouble -> Ptr CDouble -> CDouble
