Refactor palette generator ♻️

Simplified a lot of code which was unnecessarily generic.

Now using monads to manage the state of the random number generator
rather than passing it around by hand.

Also made some performance improvements, then increased the population
size so more combinations are tried in a similar length of time.
This commit is contained in:
Daniel Thwaites 2023-07-08 14:28:15 +01:00
parent c354350b9a
commit ba5565d698
No known key found for this signature in database
GPG key ID: D8AFC4BF05670F9D
6 changed files with 147 additions and 234 deletions

View file

@ -1,48 +1,27 @@
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
module Ai.Evolutionary ( EvolutionConfig(..), Species(..), evolve ) where module Ai.Evolutionary ( Species(..), evolve ) where
import Control.Applicative ( liftA2 ) import Data.Ord ( Down(Down), comparing )
import Data.Bifunctor ( first, second ) import Data.Vector ( (!) )
import Data.List ( mapAccumR, sortBy ) import qualified Data.Vector as V
import Data.Ord ( Down(Down, getDown), comparing ) import Data.Vector.Algorithms.Intro ( selectBy )
import System.Random ( RandomGen, mkStdGen, randomR ) import System.Random ( randomRIO )
import Text.Printf ( printf ) import Text.Printf ( printf )
{- | numSurvivors :: Int
Find every possible combination of two values, with the first value numSurvivors = 500
coming from one list and the second value coming from a different list.
-}
cartesianProduct :: [a] -> [b] -> [(a, b)]
cartesianProduct = liftA2 (,)
{- | numNewborns :: Int
Find every possible combination of two values, with both values coming numNewborns = 50000 - numSurvivors
from the same list. Values are allowed to be paired with themself.
-}
cartesianSquare :: [a] -> [(a, a)]
cartesianSquare as = as `cartesianProduct` as
-- | Pick a random element from a list using a random generator. mutationProbability :: Double
randomFromList :: (RandomGen r) => r -> [a] -> (a, r) mutationProbability = 0.75
randomFromList generator list
= let (index, generator') = randomR (0, length list - 1) generator
in (list !! index, generator')
{- | randomFromVector :: V.Vector a -> IO a
Map over a list, passing a random generator into the mapped randomFromVector vector = do
function each time it is called. A random generator is returned index <- randomRIO (0, V.length vector - 1)
along with the new list. return $ vector ! index
-}
mapWithGen :: (r -> a -> (r, b)) -> (r, [a]) -> (r, [b])
mapWithGen = uncurry . mapAccumR
unfoldWithGen :: (r -> (r, a)) -> Int -> r -> (r, [a])
unfoldWithGen _ 0 generator = (generator, [])
unfoldWithGen f size generator =
let (generator', as) = unfoldWithGen f (size - 1) generator
(generator'', a) = f generator'
in (generator'', a:as)
{- | {- |
A genotype is a value which is generated by the genetic algorithm. A genotype is a value which is generated by the genetic algorithm.
@ -51,141 +30,98 @@ The environment is used to specify the problem for which
we are trying to find the optimal genotype. we are trying to find the optimal genotype.
-} -}
class Species environment genotype where class Species environment genotype where
-- | Generate a new genotype at random. -- | Randomly generate a new genotype.
generate :: (RandomGen r) => environment -> r -> (r, genotype) generate :: environment -> IO genotype
-- | Randomly mutate a single genotype.
mutate :: environment -> genotype -> IO genotype
-- | Randomly combine two genotypes. -- | Randomly combine two genotypes.
crossover :: (RandomGen r) => environment -> r -> genotype -> genotype -> (r, genotype) crossover :: environment -> genotype -> genotype -> IO genotype
-- | Randomly mutate a genotype using the given environment.
mutate :: (RandomGen r) => environment -> r -> genotype -> (r, genotype)
-- | Score a genotype. Higher numbers are better. -- | Score a genotype. Higher numbers are better.
fitness :: environment -> genotype -> Double fitness :: environment -> genotype -> Double
-- | Parameters for the genetic algorithm. initialPopulation :: Species e g
data EvolutionConfig = EvolutionConfig
{ -- | The number of genotypes processed on each pass.
populationSize :: Int,
-- | How many genotypes make it through to the next pass.
survivors :: Int,
-- | The chance of a genotype being randomly changed
-- before crossover. Between 0 and 1.
mutationProbability :: Double,
-- | When the fitness score improves by less than this percentage,
-- the algorithm will stop.
changeThreshold :: Double
}
{- |
Randomly mutate the given genotype, if the mutation probability
from the 'EvolutionConfig' says yes.
-}
randomMutation :: (RandomGen r, Species e g)
=> e -- ^ Environment
-> EvolutionConfig
-> r -- ^ Random generator
-> g -- ^ Genotype to mutate
-> (r, g)
randomMutation environment config generator chromosome
= let (r, generator') = randomR (0.0, 1.0) generator
in if r <= mutationProbability config
then mutate environment generator' chromosome
else (generator', chromosome)
{- |
Select the fittest survivors from a population,
to be moved to the next pass of the algorithm.
-}
naturalSelection :: (Species e g)
=> e -- ^ Environment
-> EvolutionConfig
-> [g] -- ^ Original population
-> [(Double, g)] -- ^ Survivors with fitness scores
naturalSelection environment config
= take (survivors config)
. map (first getDown)
. sortBy (comparing fst)
-- Avoid computing fitness multiple times during sorting
-- Down reverses the sort order so that the best fitness comes first
. map (\genotype -> (Down $ fitness environment genotype, genotype))
-- | Run one pass of the genetic algorithm over a given population.
evolveGeneration :: (RandomGen r, Species e g)
=> e -- ^ Environment
-> EvolutionConfig
-> (r, [g]) -- ^ Random generator, population from previous generation
-> (r, Double, [g]) -- ^ New random generator, maximum fitness, new population
evolveGeneration environment config (generator, population)
= (newGenerator, maximum fitnesses, newPopulation)
where
(fitnesses, newPopulation) = unzip newPopulationWithFitness
(newGenerator, newPopulationWithFitness) =
second (naturalSelection environment config)
$ mapWithGen (randomMutation environment config)
$ unfoldWithGen randomCrossover (populationSize config) generator
randomCrossover gen = let (pair, gen') = randomFromList gen pairs
in (uncurry $ crossover environment gen') pair
pairs = cartesianSquare population
evolveUntilThreshold :: (RandomGen r, Species e g)
=> e -- ^ Environment
-> EvolutionConfig
-> [Double] -- ^ Fitnesses of previous generations
-> (r, [g]) -- ^ Random generator, population from previous generation
-> IO (r, [g]) -- ^ New random generator, final population
evolveUntilThreshold environment config fitnesses (generator, population) =
do
let (generator', fitness, population') =
evolveGeneration environment config (generator, population)
-- Begins at 0 on the first iteration
generationNumber = length fitnesses
fitnesses' = fitness : fitnesses
recentFitnesses = take 5 fitnesses'
{-
On the first iteration there is only one recent fitness, so the
improvement would be calculated as 0%. To prevent the algorithm
stopping immediately, we fall back to 100% in this case.
-}
change =
if generationNumber < 1
then 1
else 1 - (head recentFitnesses / last recentFitnesses);
printf "Generation: %3i Fitness: %7.1f Improvement: %5.1f%%\n"
generationNumber fitness (change * 100)
if change < changeThreshold config
then return (generator', population')
else evolveUntilThreshold environment config fitnesses' (generator', population')
{- |
Create the initial population, to be fed into the first
pass of the genetic algorithm.
-}
initialGeneration :: (RandomGen r, Species e g)
=> e -- ^ Environment => e -- ^ Environment
-> EvolutionConfig -> IO (V.Vector g) -- ^ Population
-> r -- ^ Random generator initialPopulation environment
-> (r, [g]) -- ^ New random generator, population = V.replicateM numSurvivors (generate environment)
initialGeneration environment config
= unfoldWithGen (generate environment) (survivors config)
-- | Run the full genetic algorithm. -- | Expand a population by crossovers followed by mutations.
evolvePopulation :: Species e g
=> e -- ^ Environment
-> V.Vector g -- ^ Survivors from previous generation
-> IO (V.Vector g) -- ^ New population
evolvePopulation environment population = do
let randomCrossover = do
a <- randomFromVector population
b <- randomFromVector population
crossover environment a b
randomMutation chromosome = do
r <- randomRIO (0.0, 1.0)
if r <= mutationProbability
then mutate environment chromosome
else return chromosome
newborns <- V.replicateM numNewborns randomCrossover
let nonElites = V.tail population V.++ newborns
nonElites' <- V.mapM randomMutation nonElites
return $ V.head population `V.cons` nonElites'
selectSurvivors :: Species e g
=> e -- ^ Environment
-> V.Vector g -- ^ Original population
-> (Double, V.Vector g) -- ^ Best fitness, survivors
selectSurvivors environment population =
let -- Fitness is stored to avoid calculating it for each comparison.
calculateFitness g = (fitness environment g, g)
getFitness = fst
getGenotype = snd
compareFitness = comparing $ Down . fst
-- Moves k best genotypes to the front, but doesn't sort them further.
selectBest k vector = selectBy compareFitness vector k
selected = V.modify (selectBest 1)
$ V.take numSurvivors
$ V.modify (selectBest numSurvivors)
$ V.map calculateFitness population
in ( getFitness $ V.head selected
, V.map getGenotype selected
)
shouldContinue :: [Double] -- ^ Fitness history
-> Bool
shouldContinue (x:y:_) = x /= y
shouldContinue _ = True
evolutionLoop :: Species e g
=> e -- ^ Environment
-> [Double] -- ^ Fitness history
-> V.Vector g -- ^ Survivors from previous generation
-> IO (V.Vector g) -- ^ Final population
evolutionLoop environment history survivors =
do
population <- evolvePopulation environment survivors
let (bestFitness, survivors') = selectSurvivors environment population
history' = bestFitness : history
printf "Generation: %3i Fitness: %7.1f\n"
(length history') (head history')
if shouldContinue history'
then evolutionLoop environment history' survivors'
else return survivors'
-- | Run the genetic algorithm.
evolve :: Species e g evolve :: Species e g
=> e -- ^ Environment => e -- ^ Environment
-> EvolutionConfig
-> IO g -- ^ Optimal genotype -> IO g -- ^ Optimal genotype
evolve environment config = do evolve environment = do
(_, population) <- population <- initialPopulation environment
evolveUntilThreshold environment config [] survivors <- evolutionLoop environment [] population
$ initialGeneration environment config return $ V.head survivors
$ mkStdGen 0 -- Fixed seed for determinism
return $ head population

View file

@ -1,20 +1,18 @@
module Data.Colour ( LAB(..), RGB(..), deltaE, lab2rgb, rgb2lab ) where module Data.Colour ( LAB(..), RGB(..), deltaE, lab2rgb, rgb2lab ) where
-- | Lightness A-B data LAB = LAB { lightness :: Double
data LAB a = LAB { lightness :: a , channelA :: Double
, channelA :: a , channelB :: Double
, channelB :: a }
}
-- | Red, Green, Blue data RGB = RGB { red :: Double
data RGB a = RGB { red :: a , green :: Double
, green :: a , blue :: Double
, blue :: a }
}
-- Based on https://github.com/antimatter15/rgb-lab/blob/master/color.js -- Based on https://github.com/antimatter15/rgb-lab/blob/master/color.js
deltaE :: (Floating a, Ord a) => LAB a -> LAB a -> a deltaE :: LAB -> LAB -> Double
deltaE (LAB l1 a1 b1) (LAB l2 a2 b2) = deltaE (LAB l1 a1 b1) (LAB l2 a2 b2) =
let deltaL = l1 - l2 let deltaL = l1 - l2
deltaA = a1 - a2 deltaA = a1 - a2
@ -32,7 +30,7 @@ deltaE (LAB l1 a1 b1) (LAB l2 a2 b2) =
in if i < 0 then 0 else sqrt i in if i < 0 then 0 else sqrt i
-- | Convert a 'LAB' colour to a 'RGB' colour -- | Convert a 'LAB' colour to a 'RGB' colour
lab2rgb :: (Floating a, Ord a) => LAB a -> RGB a lab2rgb :: LAB -> RGB
lab2rgb (LAB l a bx) = lab2rgb (LAB l a bx) =
let y = (l + 16) / 116 let y = (l + 16) / 116
x = a / 500 + y x = a / 500 + y
@ -52,7 +50,7 @@ lab2rgb (LAB l a bx) =
} }
-- | Convert a 'RGB' colour to a 'LAB' colour -- | Convert a 'RGB' colour to a 'LAB' colour
rgb2lab :: (Floating a, Ord a) => RGB a -> LAB a rgb2lab :: RGB -> LAB
rgb2lab (RGB r g b) = rgb2lab (RGB r g b) =
let r' = r / 255 let r' = r / 255
g' = g / 255 g' = g / 255

View file

@ -1,21 +1,14 @@
import Ai.Evolutionary ( EvolutionConfig(EvolutionConfig), evolve ) import Ai.Evolutionary ( evolve )
import Codec.Picture ( DynamicImage, Image, PixelRGB8, convertRGB8, readImage ) import Codec.Picture ( DynamicImage, convertRGB8, readImage )
import Data.Colour ( LAB, RGB(RGB), lab2rgb ) import Data.Colour ( lab2rgb )
import qualified Data.Vector as V import qualified Data.Vector as V
import Stylix.Output ( makeOutputTable ) import Stylix.Output ( makeOutputTable )
import Stylix.Palette ( ) import Stylix.Palette ( )
import System.Environment ( getArgs ) import System.Environment ( getArgs )
import System.Exit ( die ) import System.Exit ( die )
import System.Random ( setStdGen, mkStdGen )
import Text.JSON ( encode ) import Text.JSON ( encode )
-- | Run the genetic algorithm to generate a palette from the given image.
selectColours :: (Floating a, Real a)
=> String -- ^ Scheme type: "either", "light" or "dark"
-> Image PixelRGB8 -- ^ Source image
-> IO (V.Vector (LAB a)) -- ^ Generated palette
selectColours polarity image
= evolve (polarity, image) (EvolutionConfig 1000 100 0.5 0.01)
-- | Load an image file. -- | Load an image file.
loadImage :: String -- ^ Path to the file loadImage :: String -- ^ Path to the file
-> IO DynamicImage -> IO DynamicImage
@ -25,8 +18,11 @@ mainProcess :: (String, String, String) -> IO ()
mainProcess (polarity, input, output) = do mainProcess (polarity, input, output) = do
putStrLn $ "Processing " ++ input putStrLn $ "Processing " ++ input
-- Random numbers must be deterministic when running inside Nix.
setStdGen $ mkStdGen 0
image <- loadImage input image <- loadImage input
palette <- selectColours polarity (convertRGB8 image) palette <- evolve (polarity, convertRGB8 image)
let outputTable = makeOutputTable $ V.map lab2rgb palette let outputTable = makeOutputTable $ V.map lab2rgb palette
writeFile output $ encode outputTable writeFile output $ encode outputTable

View file

@ -6,23 +6,22 @@ import Data.Word ( Word8 )
import Text.JSON ( JSObject, toJSObject ) import Text.JSON ( JSObject, toJSObject )
import Text.Printf ( printf ) import Text.Printf ( printf )
-- | Convert any 'RGB' colour to store integers between 0 and 255. toHexNum :: Double -> Word8
toWord8 :: (RealFrac a) => RGB a -> RGB Word8 toHexNum = truncate
toWord8 (RGB r g b) = RGB (truncate r) (truncate g) (truncate b)
{- | {- |
Convert a colour to a hexdecimal string. Convert a colour to a hexdecimal string.
>>> toHex (RGB 255 255 255) >>> toHex (RGB 255 255 255)
"#ffffff" "ffffff"
-} -}
toHex :: RGB Word8 -> String toHex :: RGB -> String
toHex (RGB r g b) = printf "%02x%02x%02x" r g b toHex (RGB r g b) = printf "%02x%02x%02x" (toHexNum r) (toHexNum g) (toHexNum b)
-- | Convert a palette to the JSON format expected by Stylix's NixOS modules. -- | Convert a palette to the JSON format expected by Stylix's NixOS modules.
makeOutputTable :: (RealFrac a) => V.Vector (RGB a) -> JSObject String makeOutputTable :: V.Vector RGB -> JSObject String
makeOutputTable makeOutputTable
= toJSObject = toJSObject
. V.toList . V.toList
. V.imap (\i c -> (printf "base%02X" i, c)) . V.imap (\i c -> (printf "base%02X" i, c))
. V.map (toHex . toWord8) . V.map toHex

View file

@ -4,12 +4,11 @@ module Stylix.Palette ( ) where
import Ai.Evolutionary ( Species(..) ) import Ai.Evolutionary ( Species(..) )
import Codec.Picture ( Image(imageWidth, imageHeight), PixelRGB8(PixelRGB8), pixelAt ) import Codec.Picture ( Image(imageWidth, imageHeight), PixelRGB8(PixelRGB8), pixelAt )
import Data.Bifunctor ( second )
import Data.Colour ( LAB(lightness), RGB(RGB), deltaE, rgb2lab ) import Data.Colour ( LAB(lightness), RGB(RGB), deltaE, rgb2lab )
import Data.List ( delete ) import Data.List ( delete )
import Data.Vector ( (//) ) import Data.Vector ( (//) )
import qualified Data.Vector as V import qualified Data.Vector as V
import System.Random ( RandomGen, randomR ) import System.Random ( randomRIO )
-- | Extract the primary scale from a pallete. -- | Extract the primary scale from a pallete.
primary :: V.Vector a -> V.Vector a primary :: V.Vector a -> V.Vector a
@ -27,39 +26,23 @@ taken enough colours for a new palette.
alternatingZip :: V.Vector a -> V.Vector a -> V.Vector a alternatingZip :: V.Vector a -> V.Vector a -> V.Vector a
alternatingZip = V.izipWith (\i a b -> if even i then a else b) alternatingZip = V.izipWith (\i a b -> if even i then a else b)
-- | Select a random color from an image. randomFromImage :: Image PixelRGB8 -> IO LAB
randomFromImage :: (RandomGen r, Floating a, Num a, Ord a) randomFromImage image = do
=> r -- ^ Random generator x <- randomRIO (0, imageWidth image - 1)
-> Image PixelRGB8 y <- randomRIO (0, imageHeight image - 1)
-> (LAB a, r) -- ^ Chosen color, new random generator let (PixelRGB8 r g b) = pixelAt image x y
randomFromImage generator image color = RGB (fromIntegral r) (fromIntegral g) (fromIntegral b)
= let (x, generator') = randomR (0, imageWidth image - 1) generator return $ rgb2lab color
(y, generator'') = randomR (0, imageHeight image - 1) generator'
(PixelRGB8 r g b) = pixelAt image x y
color = RGB (fromIntegral r) (fromIntegral g) (fromIntegral b)
in (rgb2lab color, generator'')
instance (Floating a, Real a) => Species (String, (Image PixelRGB8)) (V.Vector (LAB a)) where instance Species (String, Image PixelRGB8) (V.Vector LAB) where
{- | generate (_, image) = V.replicateM 16 $ randomFromImage image
Palettes in the initial population are created by randomly
sampling 16 colours from the source image.
-}
generate (_, image) = generateColour 16
where generateColour 0 generator = (generator, V.empty)
generateColour n generator
= let (colour, generator') = randomFromImage generator image
in second (V.cons colour) $ generateColour (n - 1) generator'
crossover _ generator a b = (generator, alternatingZip a b) crossover _ a b = return $ alternatingZip a b
{- | mutate (_, image) palette = do
Mutation is done by replacing a random slot in the palette with index <- randomRIO (0, 15)
a new colour, which is randomly sampled from the source image. colour <- randomFromImage image
-} return $ palette // [(index, colour)]
mutate (_, image) generator palette
= let (index, generator') = randomR (0, 15) generator
(colour, generator'') = randomFromImage generator' image
in (generator'', palette // [(index, colour)])
fitness (polarity, _) palette fitness (polarity, _) palette
= realToFrac $ accentDifference - (primarySimilarity/10) - scheme = realToFrac $ accentDifference - (primarySimilarity/10) - scheme
@ -72,11 +55,11 @@ instance (Floating a, Real a) => Species (String, (Image PixelRGB8)) (V.Vector (
-- The accent colours should be as different as possible. -- The accent colours should be as different as possible.
accentDifference = minimum $ do accentDifference = minimum $ do
index_x <- [0 .. (V.length $ accent palette) - 1] index_x <- [0..7]
index_y <- delete index_x [0 .. (V.length $ accent palette) - 1] index_y <- delete index_x [0..7]
let x = (V.!) (accent palette) index_x let x = accent palette V.! index_x
let y = (V.!) (accent palette) index_y y = accent palette V.! index_y
return $ (deltaE x y) return $ deltaE x y
-- Helpers for the function below. -- Helpers for the function below.
lightnesses = V.map lightness palette lightnesses = V.map lightness palette

View file

@ -5,6 +5,7 @@ let
JuicyPixels JuicyPixels
json json
random random
vector-algorithms
]); ]);
# `nix build .#palette-generator.passthru.docs` and open in a web browser # `nix build .#palette-generator.passthru.docs` and open in a web browser