Created October 16, 2014 21:19
-
-
Save helenwilliamson/eacf7e87a9b9c7888f24 to your computer and use it in GitHub Desktop.
Scala Dojo on 16th October
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File | |
import scala.io.Source | |
/**
 * Generates random text from a first-order Markov chain (one word of context)
 * built from the words of the file named by the first command-line argument.
 *
 * Usage: MarkovChainApp <text-file>
 */
object MarkovChainApp extends App {
  require(args.length >= 1, "usage: MarkovChainApp <text-file>")

  /** Reads the whole file (closing it afterwards) and splits it on any run of whitespace. */
  private def readWords(path: String): Array[String] = {
    val source = Source.fromFile(new File(path))
    try source.mkString.split("\\s+").filter(_.nonEmpty)
    finally source.close()
  }

  /** Picks a uniformly random element, or None when there is nothing to pick from. */
  private def randomElement[A](items: Seq[A]): Option[A] =
    if (items.isEmpty) None else Some(items(scala.util.Random.nextInt(items.length)))

  /** All words of the input, in order. Newlines and repeated spaces never produce tokens. */
  val tokenisedText: Array[String] = readWords(args(0))

  /**
   * Maps each word to every word that directly follows it in the text.
   * Duplicates are kept so that frequent successors are chosen more often.
   * Words with no successor map to the empty list (default value).
   */
  val wordMap: Map[String, List[String]] =
    tokenisedText
      .sliding(2)
      .filter(_.length == 2) // a text of one word yields a single partial window; skip it
      .foldLeft(Map[String, List[String]]().withDefaultValue(List.empty[String])) { (dict, pair) =>
        dict.updated(pair(0), pair(1) :: dict(pair(0)))
      }

  /**
   * Returns a random successor of `aWord`.
   *
   * @throws IllegalArgumentException if `aWord` has no recorded successor
   */
  def nextWord(aWord: String): String =
    randomElement(wordMap(aWord)).getOrElse(
      throw new IllegalArgumentException(s"no word follows '$aWord'"))

  /** A random word that has at least one successor; the generated text starts after it. */
  val seed: String =
    randomElement(wordMap.keys.toList).getOrElse(
      throw new IllegalArgumentException(s"'${args(0)}' must contain at least two words"))

  /**
   * A random walk of 1 to 100 steps from `seed`, as (last word, text).
   * Each word in the text is preceded by a space.
   * The walk stops early at a word with no successor instead of throwing.
   * It is not used for the printed output, so it is lazy and costs nothing unless requested.
   */
  lazy val snt: (String, String) = {
    @scala.annotation.tailrec
    def walk(remaining: Int, current: String, text: String): (String, String) =
      if (remaining == 0) (current, text)
      else randomElement(wordMap(current)) match {
        case Some(next) => walk(remaining - 1, next, text + " " + next)
        case None       => (current, text)
      }
    walk(scala.util.Random.nextInt(100) + 1, seed, "")
  }

  /**
   * Follows the chain from `seed` (not included in the result), choosing a random
   * successor at each step.
   *
   * The walk stops at a word with no successor, or after a word ending in "."
   * with 50% probability. It also stops once `maxWords` words have been produced,
   * so a cycle in the text that has no sentence end cannot run forever.
   *
   * @param seed     word to start from
   * @param maxWords upper bound on the number of words returned
   * @return the generated words, in order
   */
  def recursiveFollowWords(seed: String, maxWords: Int = 1000): List[String] = {
    @scala.annotation.tailrec
    def loop(current: String, produced: Int, acc: List[String]): List[String] =
      if (produced >= maxWords) acc.reverse
      else randomElement(wordMap(current)) match {
        case None => acc.reverse
        case Some(next) =>
          if (next.endsWith(".") && scala.util.Random.nextInt(100) < 50) (next :: acc).reverse
          else loop(next, produced + 1, next :: acc)
      }
    loop(seed, 0, Nil)
  }

  println(recursiveFollowWords(seed).mkString(" "))
}
/**
 * Generates random text from a Markov chain whose state is the previous
 * `windowSize - 1` words. The chain is built from the file named by the first
 * command-line argument.
 *
 * Usage: MarkovChainWithMoreWords <text-file> <window-size>   (window-size >= 2)
 */
object MarkovChainWithMoreWords extends App {
  require(args.length >= 2, "usage: MarkovChainWithMoreWords <text-file> <window-size>")

  /** Reads the whole file (closing it afterwards) and splits it on any run of whitespace. */
  private def readWords(path: String): Array[String] = {
    val source = Source.fromFile(new File(path))
    try source.mkString.split("\\s+").filter(_.nonEmpty)
    finally source.close()
  }

  /** Picks a uniformly random element, or None when there is nothing to pick from. */
  private def randomElement[A](items: Seq[A]): Option[A] =
    if (items.isEmpty) None else Some(items(scala.util.Random.nextInt(items.length)))

  /** Words per window: `windowSize - 1` words of context plus the word they predict. */
  val windowSize: Int = args(1).toInt
  require(windowSize >= 2, s"window-size must be at least 2, got $windowSize")

  /** All words of the input, in order. Newlines and repeated spaces never produce tokens. */
  val tokenisedText: Array[String] = readWords(args(0))

  /**
   * Maps each run of `windowSize - 1` consecutive words to every word that
   * directly follows that run. Duplicates are kept so that frequent successors
   * are chosen more often. Unknown contexts map to the empty list (default value).
   */
  val wordMap: Map[List[String], List[String]] =
    tokenisedText
      .sliding(windowSize)
      // A text shorter than windowSize yields one partial window, whose "last" word
      // would also sit inside its own key; skip it.
      .filter(_.length == windowSize)
      .foldLeft(Map[List[String], List[String]]().withDefaultValue(List.empty[String])) {
        (dict, window) =>
          val key = window.init.toList
          dict.updated(key, window.last :: dict(key))
      }

  /**
   * Returns a random word that follows the context `aWord`.
   *
   * @throws IllegalArgumentException if the context has no recorded successor
   */
  def nextWord(aWord: List[String]): String =
    randomElement(wordMap(aWord)).getOrElse(
      throw new IllegalArgumentException(s"no word follows ${aWord.mkString("'", " ", "'")}"))

  /** A random context that has at least one successor; the generated text starts after it. */
  val seed: List[String] =
    randomElement(wordMap.keys.toList).getOrElse(
      throw new IllegalArgumentException(
        s"'${args(0)}' must contain at least $windowSize words"))

  /**
   * Follows the chain from the context `seed` (not included in the result),
   * sliding the context forward by one word at each step.
   *
   * The walk stops at a context with no successor, or after a word ending in "."
   * with 50% probability. It also stops once `maxWords` words have been produced,
   * so a cycle in the text that has no sentence end cannot run forever.
   *
   * @param seed     starting context of `windowSize - 1` words
   * @param maxWords upper bound on the number of words returned
   * @return the generated words, in order
   */
  def recursiveFollowWords(seed: List[String], maxWords: Int = 1000): List[String] = {
    @scala.annotation.tailrec
    def loop(context: List[String], produced: Int, acc: List[String]): List[String] =
      if (produced >= maxWords) acc.reverse
      else randomElement(wordMap(context)) match {
        case None => acc.reverse
        case Some(next) =>
          if (next.endsWith(".") && scala.util.Random.nextInt(100) < 50) (next :: acc).reverse
          else loop(context.drop(1) :+ next, produced + 1, next :: acc)
      }
    loop(seed, 0, Nil)
  }

  println(recursiveFollowWords(seed).mkString(" "))
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment