-- (c) David F. Place 2006
module Main where
import Data.List hiding (insert)
import Data.Bits
import Data.Array
import System.IO
import System

-- | A bitset of small integers: bit n is set when n is a member.
type Set = Int

-- | A puzzle as a list of rows.  Each entry is a digit, or 0 for a square
-- that has not been filled in yet.
type Sudoku = [[Int]]

-- | The sets of entries found in each row, each column and each 3x3 box,
-- in that order.
type SudokuSets = ([Set],[Set],[Set])

-- | A blank square: its row index, its column index and the digits that
-- could still be placed there.
type Blank = (Int,Int,[Int])

{-

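Usage: runghc sudoku.hs [line|grid] < puzzle.txt

The puzzle is read from standard input, with 0 marking an empty square.
The argument "line" selects the single-line format; any other argument,
or none at all, selects the grid format.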

line format example:

600500700000060040000003092003002010900000004020100500860300000050010000009004003

grid format example:

0 0 7 0 2 0 3 0 0 
5 0 0 8 0 0 0 0 2 
0 4 0 6 0 0 8 0 0 
0 0 8 0 0 0 2 4 0 
9 0 0 0 4 0 0 0 7 
0 1 2 0 0 0 6 0 0 
0 0 4 0 0 2 0 3 0 
1 0 0 0 0 5 0 0 6 
0 0 6 0 1 0 4 0 0 

-}

-- | Entry point: read a puzzle from standard input in the format chosen by
-- the command-line arguments, then solve it and print the report.
main :: IO ()
main = do
    args <- getArgs
    getContents >>= sudoku . readSudoku args

-- | Parse a puzzle in the format named by the command-line arguments.
--
-- With exactly @["line"]@, the first input line is read as consecutive
-- single digits, nine per row.  Whitespace on that line (including a
-- trailing carriage return) is ignored instead of making 'read' fail.
--
-- With any other arguments, each input line is a row of whitespace-separated
-- numbers.  Lines with no numbers on them are skipped, so a stray blank line
-- does not turn into an empty row.
readSudoku :: [String] -> String -> Sudoku
readSudoku ["line"] input =
    takeBy 9 $ map (read . (: [])) digits
    where
      firstLine = case lines input of
                    (l:_) -> l
                    []    -> error "readSudoku: empty input"
      -- 'words' splits on every kind of whitespace, so gluing its pieces
      -- back together strips spaces, tabs and '\r' from the line.
      digits = concat $ words firstLine
readSudoku _ input =
    map (map read) $ filter (not . null) $ map words $ lines input

-- | Solve a puzzle and print, line by line, the report built by 'check'.
-- Only the first solution found is requested.
sudoku :: Sudoku -> IO ()
sudoku s = mapM_ putStrLn (check s firstSolution)
    where
      firstSolution = take 1 (solveSudoku s)

-- | Every completion of a puzzle, found by backtracking.  At each step the
-- blank chosen by 'findBlank' is filled with each of its candidates in turn.
-- A blank with no candidates ends that branch.  A puzzle with no blanks is
-- returned unchanged as the only result.
solveSudoku :: Sudoku -> [Sudoku]
solveSudoku s =
    case findBlank s of
      Nothing              -> [s]
      Just (r, c, choices) -> concatMap (solveSudoku . replace s r c) choices

-- | A copy of the puzzle with the entry at row @r@, column @c@ set to @x@.
replace :: Sudoku -> Int -> Int -> Int -> Sudoku
replace s r c x =
    -- 'insertAt' passes the rows from @r@ onwards; the first one is the row
    -- being edited.
    insertAt r (\(_, fromRow) -> insertAt c (const x) (head fromRow)) s

-- | Replace the element at index @n@ with @f (before, after)@, where
-- @before@ holds the elements ahead of index @n@ and @after@ starts with the
-- element being replaced.  Despite the name, the list length is unchanged.
-- If @n@ is past the end of the list, 'tail' fails on the empty remainder.
insertAt :: Int -> (([a],[a]) -> a) -> [a] -> [a]
insertAt n f xs =
    let parts@(before, after) = splitAt n xs
    in before ++ f parts : tail after

-- | The blank square with the fewest candidate digits, together with those
-- candidates.  Ties go to the earliest blank in row-major order.  Returns
-- 'Nothing' when the puzzle has no blanks.
findBlank :: Sudoku -> Maybe Blank
findBlank s
    | null blanks = Nothing
    | otherwise   = Just (minimumBy fewerCandidates blanks)
    where
      sets = sudokuSets s
      blanks = mapBlanks candidatesAt s
      candidatesAt r c = (r, c, toList (blankSet sets r c))
      fewerCandidates (_, _, xs) (_, _, ys) = compareLength xs ys

-- | Apply @f@ to the (row, column) coordinates of every 0 entry.  Rows are
-- scanned top to bottom and each row left to right.  Only the first nine
-- rows and the first nine entries of each row are examined.
mapBlanks :: (Int -> Int -> a) -> Sudoku -> [a]
mapBlanks f s =
    [ f r c
    | (r, row) <- zip [0 .. 8] s
    , (c, x)   <- zip [0 .. 8] row
    , x == 0
    ]

-- | Compare two lists by length.  It walks both lists together and stops at
-- the end of the shorter one, so it still terminates when the other list is
-- infinite.
compareLength :: [a] -> [a] -> Ordering
compareLength xs ys =
    case (xs, ys) of
      ([], [])           -> EQ
      (_, [])            -> GT
      ([], _)            -> LT
      (_ : xs', _ : ys') -> compareLength xs' ys'

-- | The set of entries found in each row, each column and each 3x3 box.
-- Entries go in unfiltered, so a 0 entry contributes bit 0.
sudokuSets :: Sudoku -> SudokuSets
sudokuSets s =
    let toSets = map fromList
    in (toSets s, toSets (transpose s), toSets (cellContents s))

-- | The set holding every digit from 1 to 9 (bits 1 through 9 set).
universe :: Set
universe = fromList [1 .. 9]

-- | The digits that may still go in the square at row @r@, column @c@:
-- every member of 'universe' not already used in that square's row,
-- column or 3x3 box.
blankSet :: SudokuSets -> Int -> Int -> Set
blankSet (rows,columns,boxes) r c = universe .&. complement used
    where
      -- Clearing bits keeps the result inside 'universe'.  The used sets
      -- also carry bit 0 (the 0 entries marking blanks, including this very
      -- square), and xor-ing them against 'universe' would have switched
      -- that bit on in the result.
      used = (boxes !! cIx r c) .|. (rows !! r) .|. (columns !! c)

-- | The entries of each 3x3 box.  Boxes are listed left to right within each
-- band of three rows, bands top to bottom, and each box's entries row by row.
cellContents :: Sudoku -> [[Int]]
cellContents s = concatMap boxesOfBand $ takeBy 3 s
    where
      -- Split each row of a band into thirds; 'transpose' lines up the
      -- matching thirds of every row, and concatenating them gives one box.
      -- Unlike matching a fixed three-row pattern, this never fails when
      -- the last band has fewer than three rows.
      boxesOfBand = map concat . transpose . map (takeBy 3)

-- | Box index of the square at row @r@, column @c@, looked up in the
-- precomputed 'cells' table.
cIx :: Int -> Int -> Int
cIx = curry (cells !)

-- | Lookup table from (row, column) to box index, built once from
-- 'cellIndex' for every square of the 9x9 grid.
cells :: Array (Int, Int) Int
cells = listArray extent [cellIndex r c | (r, c) <- range extent]
    where
      extent = ((0,0),(8,8))

-- | Index (0-8) of the 3x3 box containing row @r@, column @c@.  Boxes are
-- numbered left to right, then top to bottom.  Coordinates are expected to
-- lie in 0..8.  Unlike a chain of guards, this is total and never hits a
-- pattern-match failure.
cellIndex :: Int -> Int -> Int
cellIndex r c = 3 * (r `div` 3) + c `div` 3

-- | Render a puzzle as text.  Each row ends in a newline, and its digits are
-- grouped in threes separated by '|'.  Bands of three rows are separated by
-- a "---+---+---" line.
showSudoku :: Sudoku -> String
showSudoku s = intercalate "---+---+---\n" (map showBand (takeBy 3 s))
    where
      showBand = concatMap (\row -> showRow row ++ "\n")
      showRow = intercalate "|" . map (concatMap show) . takeBy 3

-- | Build the output lines reporting the result of solving @puzzle@.
--
-- Each solution is re-verified with 'solves' before it is printed, so a
-- solver bug is reported instead of a wrong grid being shown.  The caller
-- currently passes at most one solution.  Longer lists are handled as well,
-- so the function is total.
check :: Sudoku -> [Sudoku] -> [String]
check puzzle [] = [showSudoku puzzle,"No solutions."]
check puzzle solutions
      | all (`solves` puzzle) solutions =
          "Puzzle:" : showSudoku puzzle : concatMap labelled solutions
      | otherwise = ["Program Error.  Incorrect Solution!"]
    where
      labelled solution = ["Solution:",showSudoku solution]

-- | The set with no members (no bits set).
empty :: Set
empty =
    0

-- | Build a set by setting bit @x@ for every @x@ in the list.  Nothing is
-- filtered out, so a 0 entry sets bit 0.
fromList :: [Int] -> Set
fromList xs = foldl' (\acc x -> acc .|. bit x) empty xs

-- | The digits 1-9 present in a set, largest first.  Other bits, such as
-- bit 0, are ignored.
toList :: Set -> [Int]
toList set = filter (testBit set) [9, 8 .. 1]

-- | Add a value to a set.  The blank marker 0 is skipped, leaving the set
-- unchanged; every other value sets its own bit.
insert :: Set->Int->Int
insert set digit
    | digit == 0 = set
    | otherwise  = setBit set digit

-- | The row, column and box sets of a completely solved puzzle: every one
-- of the nine sets in each group is 'universe'.
solvedSudokuSets :: SudokuSets
solvedSudokuSets =
    let full = replicate 9 universe
    in (full, full, full)

-- | True when every row, column and box of the puzzle holds exactly the
-- digits 1-9, as described by 'solvedSudokuSets'.
isSolution :: Sudoku -> Bool
isSolution = (== solvedSudokuSets) . sudokuSets

-- | @a `solves` b@: @a@ is a valid completed grid and agrees with puzzle @b@
-- on every square @b@ has already filled in (non-zero entries).
solves :: Sudoku -> Sudoku -> Bool
solves a b = isSolution a && and (zipWith agrees (concat a) (concat b))
    where
      -- A 0 in the puzzle is a blank, so any value there is acceptable.
      agrees _ 0 = True
      agrees x y = x == y

-- | Split a list into consecutive chunks of @n@ elements; the last chunk may
-- be shorter.  For @n <= 0@ and non-empty input the result is an infinite
-- list of empty chunks.
takeBy :: Int -> [a] -> [[a]]
takeBy n xs =
    case xs of
      [] -> []
      _  -> let (chunk, rest) = splitAt n xs
            in chunk : takeBy n rest
