module LinearAlgebra where

import Common (allCellSelections)

import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Layout as Layout
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Singular as Singular
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix (ShapeInt, (#*|))
import Numeric.LAPACK.Format ((##))

import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Bool as ComfortSet
import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Map as Map
import Data.Array.Comfort.Storable (Array)
import Data.Map (Map)
import Data.Set (Set)

import Data.Foldable (for_)
import Data.Tuple.HT (mapSnd)



{- $setup
>>> import qualified Data.Array.Comfort.Storable as Array
>>> import qualified Data.Map as Map
>>> import Data.Tuple.HT (mapSnd)
>>>
>>> myRound :: Double -> Integer
>>> myRound = round
-}


-- | Base row of the small demonstration pyramid printed by 'test'.
example0 :: [Double]
example0 = [3, 1, 4, 1, 5]

{- |
Sample puzzle: a pyramid with five rows of which only five cells are given.
Unknown cells are drawn as @...@:

>         ...
>       ... ...
>     ... ... ...
>   ... 100 ...  79
>  31  41 ...  26 ...

Row @i@ lists its @i+1@ cells from left to right;
'Nothing' marks an unknown cell, 'Just' a given value.
-}
example2 :: [[Maybe Double]]
example2 =
   let o = Nothing
       v = Just
   in [              [o]
      ,           [o,    o]
      ,        [o,    o,    o]
      ,     [o,  v 100,  o,  v 79]
      , [v 31, v 41,  o,  v 26,  o]
      ]

{- |
Build the complete sum pyramid above a given base row.

The base row @xs@ becomes the last row (row index @size xs - 1@).
Every other cell @(row, col)@ is the sum of the two cells
@(row+1, col)@ and @(row+1, col+1)@ directly below it.
The table is a lazy boxed array that refers to itself,
so each cell is computed from the row below on demand.
-}
pyramid ::
   Array ShapeInt Double ->
   Array (Shape.LowerTriangular ShapeInt) Double
pyramid xs = Array.fromBoxed cells
   where
      baseShape = Array.shape xs
      lastRow = Shape.size baseShape - 1
      cells =
         fmap cellValue $ BoxedArray.indices $ Shape.lowerTriangular baseShape
      cellValue (row, col)
         | row == lastRow = xs Array.! col
         | otherwise = below col + below (col+1)
         where below c = cells BoxedArray.! (row+1, c)

{- |
Linear map from a base row to the full pyramid, as a matrix.

Row @k@ is the pyramid generated by the @k@-th unit base vector.
Hence entry @(k, (i,j))@ is the coefficient with which base value @k@
contributes to pyramid cell @(i,j)@.
-}
basis ::
   ShapeInt -> Matrix.General ShapeInt (Shape.LowerTriangular ShapeInt) Double
basis shape@(Shape.ZeroBased n) =
   Matrix.fromRows (Shape.lowerTriangular shape)
      [pyramid (Vector.unit shape k) | k <- [0 .. n-1]]

-- | Convert a puzzle given row by row into a map from zero-based
-- @(row, column)@ positions to the known values.
-- 'Nothing' entries are skipped.
addIndices :: [[Maybe a]] -> Map (Int,Int) a
addIndices puzzle =
   Map.fromList
      [ ((i,j), x)
      | (i, row) <- zip [0..] puzzle
      , (j, Just x) <- zip [0..] row
      ]

{- |
Reconstruct a pyramid with @n@ rows from the cells given in @indexed@
(keyed by @(row, column)@) by a least-squares fit of the base row.

The first component is the rank reported by
'Singular.leastSquaresMinimumNormRCond', which is called with relative
threshold @1e-5@.
Up to that numerical threshold, it equals @n@ exactly when the given cells
determine the pyramid uniquely.
The second component is the complete pyramid generated by the
minimum-norm base row.

>>> mapSnd (map myRound . Array.toList) $ solve 3 $ Map.fromList [((0,0),8), ((2,0),1), ((2,2),3)]
(3,[8,3,5,1,2,3])
-}
solve ::
   Int -> Map (Int,Int) Double ->
   (Int, Array (Shape.LowerTriangular ShapeInt) Double)
solve n indexed = (rank, fullBasis #*| Matrix.flattenColumn coefficients)
   where
      -- rows: pyramid cells, columns: base-row entries
      fullBasis = Matrix.transpose $ basis $ Matrix.shapeInt n
      -- keep only the equations for the cells that are actually given
      equations = Matrix.takeRowSet (Map.keysSet indexed) fullBasis
      rhs = Matrix.singleColumn Layout.ColumnMajor $ Array.fromMap indexed
      (rank, coefficients) =
         Singular.leastSquaresMinimumNormRCond 1e-5 equations rhs

{- |
Decide whether the selected cells @ixs@ determine a pyramid with @n@ rows
uniquely, i.e. whether the columns of 'basis' picked by @ixs@ have rank @n@.

The rank is counted as the number of leading singular values that exceed
the /absolute/ threshold @1e-5@. This differs from the relative threshold
used in 'solve'. The count assumes the singular values come in
descending order.

Multiplying the result of 'Singular.values' by an all-ones vector over @ixs@
collects the singular values into a vector of length @n@.
NOTE(review): this relies on the shape of the 'Singular.values' result in the
lapack version used here; confirm when upgrading.

The partial application @solvable n@ builds the basis only once, so it can be
reused for many selections.
-}
solvable :: Int -> Set (Int,Int) -> Bool
solvable n = isSolvable
   where
      fullBasis = basis (Matrix.shapeInt n)
      isSolvable ixs = n == numericalRank ixs
      numericalRank ixs =
         length $ takeWhile (1e-5 <) $ Vector.toList $
            Singular.values (Matrix.takeColumnSet ixs fullBasis)
               #*| Vector.one ixs

{- |
All candidate cell selections from 'allCellSelections' for a pyramid with
@n@ rows that 'solvable' accepts.

>>> map (length . solvables) [0..5]
[1,1,3,17,149,1824]
-}
solvables :: Int -> [Set (Int,Int)]
solvables n = [ixs | ixs <- allCellSelections n, isSolvable ixs]
   where
      -- bound once so that the basis inside 'solvable' is shared
      isSolvable = solvable n

{- |
Check that no sub-pyramid is overcrowded.

The function returns 'True' iff, for every apex @(i,j)@ and every size @k@,
the sub-pyramid of size @k@ with that apex contains at most @k@ selected cells.
That sub-pyramid consists of the cells @(i+di, j+dj)@ with
@0 <= dj <= di < k@.

All values of a size-@k@ sub-pyramid are determined by its @k@-cell base row.
So more than @k@ selected cells inside it are linearly dependent,
and the selection carries redundant information.
'counterexamples' shows that passing this check does not imply 'solvable'.
-}
wellcrowded :: Int -> Set (Int,Int) -> Bool
wellcrowded n ixs =
   all notOvercrowded
      [(apex, k) | apex@(i,_) <- Shape.indices triShape, k <- [0 .. n-i]]
   where
      triShape = Shape.lowerTriangular $ Matrix.shapeInt n
      given = ComfortSet.fromSet triShape ixs
      notOvercrowded (apex, k) = countGiven apex k <= k
      -- number of selected cells in the size-k sub-pyramid below (i,j)
      countGiven (i,j) k =
         length
            [ ()
            | (di,dj) <- Shape.indices $ Shape.lowerTriangular $ Matrix.shapeInt k
            , ComfortSet.member (i+di, j+dj) given
            ]


{- |
Candidate selections for a pyramid with @n@ rows on which the
linear-independence criterion 'solvable' and the sub-pyramid criterion
'wellcrowded' disagree.

The criteria do not match. The smallest counterexample is shown below,
with @*@ marking a given cell and @.@ an unknown one:

>    *
>   . .
>  . * .
> * . . *

No sub-pyramid is overcrowded, yet the four given cells are dependent.
The apex equals the sum of the two bottom corners plus three times the
middle cell of row 2.
-}
counterexamples :: Int -> [Set (Int, Int)]
counterexamples n =
   [ixs | ixs <- allCellSelections n, disagree ixs]
   where
      -- bound once so that the basis inside 'solvable' is shared
      isSolvable = solvable n
      disagree ixs = isSolvable ixs /= wellcrowded n ixs


-- | Render a cell selection of a pyramid with @n@ rows as a lower-triangular
-- matrix for printing: @1@ at every selected cell and @0@ everywhere else.
boolMatrix :: Int -> Set (Int, Int) -> Matrix.Lower ShapeInt Float
boolMatrix n ixs =
   let triShape = Shape.lowerTriangular (Matrix.shapeInt n)
       marks = Map.assocs (Map.fromSet (\_ -> (1 :: Float)) ixs)
   in Triangular.fromLowerRowMajor (Array.fromAssociations 0 triShape marks)


-- | Demonstration. Prints, in order:
--
-- * the pyramid built on 'example0',
-- * the rank and reconstruction for 'example2',
-- * every solvable selection for three rows,
-- * every 'counterexamples' selection for zero to five rows.
test :: IO ()
test = do
   let printInts m = m ## "%.0f"
       printSelection rows ixs = do
          putStrLn ""
          printInts (boolMatrix rows ixs)

   printInts $ Triangular.fromLowerRowMajor $
      pyramid $ Vector.autoFromList example0

   printInts $ mapSnd Triangular.fromLowerRowMajor $
      solve (length example2) (addIndices example2)

   putStrLn "\nsolvable:"
   mapM_ (printSelection 3) (solvables 3)

   putStrLn "\nunsolvable:"
   for_ [0..5] $ \rows ->
      mapM_ (printSelection rows) (counterexamples rows)

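{-
https://oeis.org/A014068

map (\n -> Comb.binomial (div (n*(n+1)) 2) n) [1..10::Integer]

Number of ways to choose n cells out of the n*(n+1)/2 cells
of a pyramid with n rows, that is, binomial (n*(n+1)/2) n.
Compare with the counts of solvable selections in the 'solvables' doctest.
-}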
