Refactor palette generator ♻️

Simplified a lot of code which was unnecessarily generic.

Now using monads to manage the state of the random number generator
rather than passing it around by hand.

Also made some performance improvements, then increased the population
size so more combinations are tried in a similar length of time.
This commit is contained in:
Daniel Thwaites 2023-07-08 14:28:15 +01:00
parent c354350b9a
commit ba5565d698
No known key found for this signature in database
GPG key ID: D8AFC4BF05670F9D
6 changed files with 147 additions and 234 deletions

View file

@ -1,48 +1,27 @@
{-# LANGUAGE MultiParamTypeClasses #-}
module Ai.Evolutionary ( EvolutionConfig(..), Species(..), evolve ) where
module Ai.Evolutionary ( Species(..), evolve ) where
import Control.Applicative ( liftA2 )
import Data.Bifunctor ( first, second )
import Data.List ( mapAccumR, sortBy )
import Data.Ord ( Down(Down, getDown), comparing )
import System.Random ( RandomGen, mkStdGen, randomR )
import Data.Ord ( Down(Down), comparing )
import Data.Vector ( (!) )
import qualified Data.Vector as V
import Data.Vector.Algorithms.Intro ( selectBy )
import System.Random ( randomRIO )
import Text.Printf ( printf )
{- |
Find every possible combination of two values, with the first value
coming from one list and the second value coming from a different list.
-}
cartesianProduct :: [a] -> [b] -> [(a, b)]
cartesianProduct = liftA2 (,)
numSurvivors :: Int
numSurvivors = 500
{- |
Find every possible combination of two values, with both values coming
from the same list. Values are allowed to be paired with themself.
-}
cartesianSquare :: [a] -> [(a, a)]
cartesianSquare as = as `cartesianProduct` as
numNewborns :: Int
numNewborns = 50000 - numSurvivors
-- | Pick a random element from a list using a random generator.
randomFromList :: (RandomGen r) => r -> [a] -> (a, r)
randomFromList generator list
= let (index, generator') = randomR (0, length list - 1) generator
in (list !! index, generator')
mutationProbability :: Double
mutationProbability = 0.75
{- |
Map over a list, passing a random generator into the mapped
function each time it is called. A random generator is returned
along with the new list.
-}
mapWithGen :: (r -> a -> (r, b)) -> (r, [a]) -> (r, [b])
mapWithGen = uncurry . mapAccumR
unfoldWithGen :: (r -> (r, a)) -> Int -> r -> (r, [a])
unfoldWithGen _ 0 generator = (generator, [])
unfoldWithGen f size generator =
let (generator', as) = unfoldWithGen f (size - 1) generator
(generator'', a) = f generator'
in (generator'', a:as)
randomFromVector :: V.Vector a -> IO a
randomFromVector vector = do
index <- randomRIO (0, V.length vector - 1)
return $ vector ! index
{- |
A genotype is a value which is generated by the genetic algorithm.
@ -51,141 +30,98 @@ The environment is used to specify the problem for which
we are trying to find the optimal genotype.
-}
class Species environment genotype where
-- | Generate a new genotype at random.
generate :: (RandomGen r) => environment -> r -> (r, genotype)
-- | Randomly generate a new genotype.
generate :: environment -> IO genotype
-- | Randomly mutate a single genotype.
mutate :: environment -> genotype -> IO genotype
-- | Randomly combine two genotypes.
crossover :: (RandomGen r) => environment -> r -> genotype -> genotype -> (r, genotype)
-- | Randomly mutate a genotype using the given environment.
mutate :: (RandomGen r) => environment -> r -> genotype -> (r, genotype)
crossover :: environment -> genotype -> genotype -> IO genotype
-- | Score a genotype. Higher numbers are better.
fitness :: environment -> genotype -> Double
-- | Parameters for the genetic algorithm.
data EvolutionConfig = EvolutionConfig
{ -- | The number of genotypes processed on each pass.
populationSize :: Int,
-- | How many genotypes make it through to the next pass.
survivors :: Int,
-- | The chance of a genotype being randomly changed
-- before crossover. Between 0 and 1.
mutationProbability :: Double,
-- | When the fitness score improves by less than this percentage,
-- the algorithm will stop.
changeThreshold :: Double
}
{- |
Randomly mutate the given genotype, if the mutation probability
from the 'EvolutionConfig' says yes.
-}
randomMutation :: (RandomGen r, Species e g)
initialPopulation :: Species e g
=> e -- ^ Environment
-> EvolutionConfig
-> r -- ^ Random generator
-> g -- ^ Genotype to mutate
-> (r, g)
randomMutation environment config generator chromosome
= let (r, generator') = randomR (0.0, 1.0) generator
in if r <= mutationProbability config
then mutate environment generator' chromosome
else (generator', chromosome)
-> IO (V.Vector g) -- ^ Population
initialPopulation environment
= V.replicateM numSurvivors (generate environment)
{- |
Select the fittest survivors from a population,
to be moved to the next pass of the algorithm.
-}
naturalSelection :: (Species e g)
-- | Expand a population by crossovers followed by mutations.
evolvePopulation :: Species e g
=> e -- ^ Environment
-> EvolutionConfig
-> [g] -- ^ Original population
-> [(Double, g)] -- ^ Survivors with fitness scores
naturalSelection environment config
= take (survivors config)
. map (first getDown)
. sortBy (comparing fst)
-- Avoid computing fitness multiple times during sorting
-- Down reverses the sort order so that the best fitness comes first
. map (\genotype -> (Down $ fitness environment genotype, genotype))
-> V.Vector g -- ^ Survivors from previous generation
-> IO (V.Vector g) -- ^ New population
evolvePopulation environment population = do
let randomCrossover = do
a <- randomFromVector population
b <- randomFromVector population
crossover environment a b
-- | Run one pass of the genetic algorithm over a given population.
evolveGeneration :: (RandomGen r, Species e g)
randomMutation chromosome = do
r <- randomRIO (0.0, 1.0)
if r <= mutationProbability
then mutate environment chromosome
else return chromosome
newborns <- V.replicateM numNewborns randomCrossover
let nonElites = V.tail population V.++ newborns
nonElites' <- V.mapM randomMutation nonElites
return $ V.head population `V.cons` nonElites'
selectSurvivors :: Species e g
=> e -- ^ Environment
-> EvolutionConfig
-> (r, [g]) -- ^ Random generator, population from previous generation
-> (r, Double, [g]) -- ^ New random generator, maximum fitness, new population
evolveGeneration environment config (generator, population)
= (newGenerator, maximum fitnesses, newPopulation)
where
(fitnesses, newPopulation) = unzip newPopulationWithFitness
-> V.Vector g -- ^ Original population
-> (Double, V.Vector g) -- ^ Best fitness, survivors
selectSurvivors environment population =
let -- Fitness is stored to avoid calculating it for each comparison.
calculateFitness g = (fitness environment g, g)
getFitness = fst
getGenotype = snd
compareFitness = comparing $ Down . fst
(newGenerator, newPopulationWithFitness) =
second (naturalSelection environment config)
$ mapWithGen (randomMutation environment config)
$ unfoldWithGen randomCrossover (populationSize config) generator
-- Moves k best genotypes to the front, but doesn't sort them further.
selectBest k vector = selectBy compareFitness vector k
randomCrossover gen = let (pair, gen') = randomFromList gen pairs
in (uncurry $ crossover environment gen') pair
selected = V.modify (selectBest 1)
$ V.take numSurvivors
$ V.modify (selectBest numSurvivors)
$ V.map calculateFitness population
pairs = cartesianSquare population
in ( getFitness $ V.head selected
, V.map getGenotype selected
)
evolveUntilThreshold :: (RandomGen r, Species e g)
shouldContinue :: [Double] -- ^ Fitness history
-> Bool
shouldContinue (x:y:_) = x /= y
shouldContinue _ = True
evolutionLoop :: Species e g
=> e -- ^ Environment
-> EvolutionConfig
-> [Double] -- ^ Fitnesses of previous generations
-> (r, [g]) -- ^ Random generator, population from previous generation
-> IO (r, [g]) -- ^ New random generator, final population
evolveUntilThreshold environment config fitnesses (generator, population) =
-> [Double] -- ^ Fitness history
-> V.Vector g -- ^ Survivors from previous generation
-> IO (V.Vector g) -- ^ Final population
evolutionLoop environment history survivors =
do
let (generator', fitness, population') =
evolveGeneration environment config (generator, population)
population <- evolvePopulation environment survivors
-- Begins at 0 on the first iteration
generationNumber = length fitnesses
let (bestFitness, survivors') = selectSurvivors environment population
history' = bestFitness : history
fitnesses' = fitness : fitnesses
recentFitnesses = take 5 fitnesses'
printf "Generation: %3i Fitness: %7.1f\n"
(length history') (head history')
{-
On the first iteration there is only one recent fitness, so the
improvement would be calculated as 0%. To prevent the algorithm
stopping immediately, we fall back to 100% in this case.
-}
change =
if generationNumber < 1
then 1
else 1 - (head recentFitnesses / last recentFitnesses);
if shouldContinue history'
then evolutionLoop environment history' survivors'
else return survivors'
printf "Generation: %3i Fitness: %7.1f Improvement: %5.1f%%\n"
generationNumber fitness (change * 100)
if change < changeThreshold config
then return (generator', population')
else evolveUntilThreshold environment config fitnesses' (generator', population')
{- |
Create the initial population, to be fed into the first
pass of the genetic algorithm.
-}
initialGeneration :: (RandomGen r, Species e g)
=> e -- ^ Environment
-> EvolutionConfig
-> r -- ^ Random generator
-> (r, [g]) -- ^ New random generator, population
initialGeneration environment config
= unfoldWithGen (generate environment) (survivors config)
-- | Run the full genetic algorithm.
-- | Run the genetic algorithm.
evolve :: Species e g
=> e -- ^ Environment
-> EvolutionConfig
-> IO g -- ^ Optimal genotype
evolve environment config = do
(_, population) <-
evolveUntilThreshold environment config []
$ initialGeneration environment config
$ mkStdGen 0 -- Fixed seed for determinism
return $ head population
evolve environment = do
population <- initialPopulation environment
survivors <- evolutionLoop environment [] population
return $ V.head survivors

View file

@ -1,20 +1,18 @@
module Data.Colour ( LAB(..), RGB(..), deltaE, lab2rgb, rgb2lab ) where
-- | Lightness A-B
data LAB a = LAB { lightness :: a
, channelA :: a
, channelB :: a
data LAB = LAB { lightness :: Double
, channelA :: Double
, channelB :: Double
}
-- | Red, Green, Blue
data RGB a = RGB { red :: a
, green :: a
, blue :: a
data RGB = RGB { red :: Double
, green :: Double
, blue :: Double
}
-- Based on https://github.com/antimatter15/rgb-lab/blob/master/color.js
deltaE :: (Floating a, Ord a) => LAB a -> LAB a -> a
deltaE :: LAB -> LAB -> Double
deltaE (LAB l1 a1 b1) (LAB l2 a2 b2) =
let deltaL = l1 - l2
deltaA = a1 - a2
@ -32,7 +30,7 @@ deltaE (LAB l1 a1 b1) (LAB l2 a2 b2) =
in if i < 0 then 0 else sqrt i
-- | Convert a 'LAB' colour to a 'RGB' colour
lab2rgb :: (Floating a, Ord a) => LAB a -> RGB a
lab2rgb :: LAB -> RGB
lab2rgb (LAB l a bx) =
let y = (l + 16) / 116
x = a / 500 + y
@ -52,7 +50,7 @@ lab2rgb (LAB l a bx) =
}
-- | Convert a 'RGB' colour to a 'LAB' colour
rgb2lab :: (Floating a, Ord a) => RGB a -> LAB a
rgb2lab :: RGB -> LAB
rgb2lab (RGB r g b) =
let r' = r / 255
g' = g / 255

View file

@ -1,21 +1,14 @@
import Ai.Evolutionary ( EvolutionConfig(EvolutionConfig), evolve )
import Codec.Picture ( DynamicImage, Image, PixelRGB8, convertRGB8, readImage )
import Data.Colour ( LAB, RGB(RGB), lab2rgb )
import Ai.Evolutionary ( evolve )
import Codec.Picture ( DynamicImage, convertRGB8, readImage )
import Data.Colour ( lab2rgb )
import qualified Data.Vector as V
import Stylix.Output ( makeOutputTable )
import Stylix.Palette ( )
import System.Environment ( getArgs )
import System.Exit ( die )
import System.Random ( setStdGen, mkStdGen )
import Text.JSON ( encode )
-- | Run the genetic algorithm to generate a palette from the given image.
selectColours :: (Floating a, Real a)
=> String -- ^ Scheme type: "either", "light" or "dark"
-> Image PixelRGB8 -- ^ Source image
-> IO (V.Vector (LAB a)) -- ^ Generated palette
selectColours polarity image
= evolve (polarity, image) (EvolutionConfig 1000 100 0.5 0.01)
-- | Load an image file.
loadImage :: String -- ^ Path to the file
-> IO DynamicImage
@ -25,8 +18,11 @@ mainProcess :: (String, String, String) -> IO ()
mainProcess (polarity, input, output) = do
putStrLn $ "Processing " ++ input
-- Random numbers must be deterministic when running inside Nix.
setStdGen $ mkStdGen 0
image <- loadImage input
palette <- selectColours polarity (convertRGB8 image)
palette <- evolve (polarity, convertRGB8 image)
let outputTable = makeOutputTable $ V.map lab2rgb palette
writeFile output $ encode outputTable

View file

@ -6,23 +6,22 @@ import Data.Word ( Word8 )
import Text.JSON ( JSObject, toJSObject )
import Text.Printf ( printf )
-- | Convert any 'RGB' colour to store integers between 0 and 255.
toWord8 :: (RealFrac a) => RGB a -> RGB Word8
toWord8 (RGB r g b) = RGB (truncate r) (truncate g) (truncate b)
toHexNum :: Double -> Word8
toHexNum = truncate
{- |
Convert a colour to a hexdecimal string.
>>> toHex (RGB 255 255 255)
"#ffffff"
"ffffff"
-}
toHex :: RGB Word8 -> String
toHex (RGB r g b) = printf "%02x%02x%02x" r g b
toHex :: RGB -> String
toHex (RGB r g b) = printf "%02x%02x%02x" (toHexNum r) (toHexNum g) (toHexNum b)
-- | Convert a palette to the JSON format expected by Stylix's NixOS modules.
makeOutputTable :: (RealFrac a) => V.Vector (RGB a) -> JSObject String
makeOutputTable :: V.Vector RGB -> JSObject String
makeOutputTable
= toJSObject
. V.toList
. V.imap (\i c -> (printf "base%02X" i, c))
. V.map (toHex . toWord8)
. V.map toHex

View file

@ -4,12 +4,11 @@ module Stylix.Palette ( ) where
import Ai.Evolutionary ( Species(..) )
import Codec.Picture ( Image(imageWidth, imageHeight), PixelRGB8(PixelRGB8), pixelAt )
import Data.Bifunctor ( second )
import Data.Colour ( LAB(lightness), RGB(RGB), deltaE, rgb2lab )
import Data.List ( delete )
import Data.Vector ( (//) )
import qualified Data.Vector as V
import System.Random ( RandomGen, randomR )
import System.Random ( randomRIO )
-- | Extract the primary scale from a pallete.
primary :: V.Vector a -> V.Vector a
@ -27,39 +26,23 @@ taken enough colours for a new palette.
alternatingZip :: V.Vector a -> V.Vector a -> V.Vector a
alternatingZip = V.izipWith (\i a b -> if even i then a else b)
-- | Select a random color from an image.
randomFromImage :: (RandomGen r, Floating a, Num a, Ord a)
=> r -- ^ Random generator
-> Image PixelRGB8
-> (LAB a, r) -- ^ Chosen color, new random generator
randomFromImage generator image
= let (x, generator') = randomR (0, imageWidth image - 1) generator
(y, generator'') = randomR (0, imageHeight image - 1) generator'
(PixelRGB8 r g b) = pixelAt image x y
randomFromImage :: Image PixelRGB8 -> IO LAB
randomFromImage image = do
x <- randomRIO (0, imageWidth image - 1)
y <- randomRIO (0, imageHeight image - 1)
let (PixelRGB8 r g b) = pixelAt image x y
color = RGB (fromIntegral r) (fromIntegral g) (fromIntegral b)
in (rgb2lab color, generator'')
return $ rgb2lab color
instance (Floating a, Real a) => Species (String, (Image PixelRGB8)) (V.Vector (LAB a)) where
{- |
Palettes in the initial population are created by randomly
sampling 16 colours from the source image.
-}
generate (_, image) = generateColour 16
where generateColour 0 generator = (generator, V.empty)
generateColour n generator
= let (colour, generator') = randomFromImage generator image
in second (V.cons colour) $ generateColour (n - 1) generator'
instance Species (String, Image PixelRGB8) (V.Vector LAB) where
generate (_, image) = V.replicateM 16 $ randomFromImage image
crossover _ generator a b = (generator, alternatingZip a b)
crossover _ a b = return $ alternatingZip a b
{- |
Mutation is done by replacing a random slot in the palette with
a new colour, which is randomly sampled from the source image.
-}
mutate (_, image) generator palette
= let (index, generator') = randomR (0, 15) generator
(colour, generator'') = randomFromImage generator' image
in (generator'', palette // [(index, colour)])
mutate (_, image) palette = do
index <- randomRIO (0, 15)
colour <- randomFromImage image
return $ palette // [(index, colour)]
fitness (polarity, _) palette
= realToFrac $ accentDifference - (primarySimilarity/10) - scheme
@ -72,11 +55,11 @@ instance (Floating a, Real a) => Species (String, (Image PixelRGB8)) (V.Vector (
-- The accent colours should be as different as possible.
accentDifference = minimum $ do
index_x <- [0 .. (V.length $ accent palette) - 1]
index_y <- delete index_x [0 .. (V.length $ accent palette) - 1]
let x = (V.!) (accent palette) index_x
let y = (V.!) (accent palette) index_y
return $ (deltaE x y)
index_x <- [0..7]
index_y <- delete index_x [0..7]
let x = accent palette V.! index_x
y = accent palette V.! index_y
return $ deltaE x y
-- Helpers for the function below.
lightnesses = V.map lightness palette

View file

@ -5,6 +5,7 @@ let
JuicyPixels
json
random
vector-algorithms
]);
# `nix build .#palette-generator.passthru.docs` and open in a web browser