Tuesday, April 13, 2010

Creating Custom Traversable implementations

One of the most talked-about features of Scala 2.8 is the improved collections library. Creating your own implementation is trivial; however, if you want your new collection to behave the same way as the built-in collections, there are a few tips you need to be aware of.

Note: all of these examples can either be run in the REPL or put in a file and run with scala <filename>.

Starting with the simple implementation:
  import scala.collection._
  import scala.collection.generic._
  class MyColl[A](seq : A*) extends Traversable[A] {
    // the only abstract method in Traversable is foreach... easy :)
    def foreach[U](f: A => U) = util.Random.shuffle(seq.toSeq).foreach(f)
  }

This is a silly collection, I admit, but it is custom :).

This example works, but if you test the result of a map operation (or any other operation that returns a new collection), you will notice that it is not an instance of MyColl. This is expected: unless told otherwise, Traversable's operations build and return a new Traversable.

To demonstrate, run the following tests:
  val c = new MyColl(1, 2, 3)
  // printed twice: the order may differ between calls because foreach shuffles
  println(c mkString ",")
  println(c mkString ",")
  println(c drop 1 mkString ",")
  // the next two assertions fail (see the following explanation)
  assert(c.drop(1).isInstanceOf[MyColl[_]])
  assert((c map {_ + 1}).isInstanceOf[MyColl[_]])

Both assertions fail. Because the collection is immutable, filter/map/etc. must necessarily return a new object, and since the Traversable trait creates those new objects as instances of Traversable, neither result is a MyColl. The easiest way to make these methods return an instance of MyColl is to make the following changes and additions:
  import scala.collection._
  import scala.collection.generic._
  /*
  Adding GenericTraversableTemplate delegates the creation of new
  collections to the companion object. Adding the trait and the
  companion object makes all the new collections instances of MyColl.
  */
  class MyColl[A](seq : A*) extends Traversable[A]
                               with GenericTraversableTemplate[A, MyColl] {
    override def companion = MyColl
    def foreach[U](f: A => U) = util.Random.shuffle(seq.toSeq).foreach(f)
  }
  // GenericTraversableTemplate needs a companion; extending TraversableFactory provides one
  object MyColl extends TraversableFactory[MyColl] {
    /*
    Many methods in TraversableLike (map, for example) take an implicit
    CanBuildFrom parameter. It determines how the returned collection is
    built; for example, one could make a List's map method return a Set.
    Here we define the default CanBuildFrom for MyColl.
    */
    implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MyColl[A]] = new GenericCanBuildFrom[A]
    /*
    The method that builds the new collection. This is a simple
    implementation, but it works. The library provides other builders
    that can help if needed.
    */
    def newBuilder[A] = new scala.collection.mutable.LazyBuilder[A,MyColl[A]] {
      def result() = {
        // concatenate all the parts added to the builder into one List
        val data = parts.foldLeft(List[A]()){(l,n) => l ++ n}
        new MyColl(data:_*)
      }
    }
  }
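The canBuildFrom comment above describes the mechanism: map takes an implicit builder-factory parameter, and the compiler finds the default instance in the collection's companion object. The sketch below is a simplified analogue of that mechanism, not the library's code: MyBuildFrom, Box and defaultBuildFrom are hypothetical names, and the sketch avoids the 2.8 collection traits so that it stands on its own.

```scala
// Hypothetical, simplified analogue of CanBuildFrom (NOT the library trait).
// From: source collection type, Elem: new element type, To: result type.
trait MyBuildFrom[-From, -Elem, +To] {
  def fromSeq(xs: Seq[Elem]): To
}

// Hypothetical collection whose map delegates result construction to the implicit.
class Box[A](val elems: Seq[A]) {
  def map[B, That](f: A => B)(implicit bf: MyBuildFrom[Box[A], B, That]): That =
    bf.fromSeq(elems.map(f))
}

object Box {
  // Default instance: found in Box's companion object, which is part of
  // the implicit scope of MyBuildFrom[Box[A], B, That].
  implicit def defaultBuildFrom[B]: MyBuildFrom[Box[_], B, Box[B]] =
    new MyBuildFrom[Box[_], B, Box[B]] {
      def fromSeq(xs: Seq[B]): Box[B] = new Box(xs)
    }
}

// The compiler picks defaultBuildFrom and infers That = Box[Long].
val boxed: Box[Long] = new Box(Seq(1, 2, 3)).map(_.toLong)
```

As with CanBuildFrom, the From parameter is contravariant, so a single instance declared for Box[_] serves every Box[A].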

Now the various filter/map/etc. methods create instances of MyColl, which is fine as long as the MyColl type is not needed at compile time. But suppose we add a method to the class and want it to be accessible after applying methods like map and filter.

Adding val o : MyColl[Long] = c map {_.toLong} to the tests causes a compilation error because the static return type is Traversable[Long]. The fix is easy.

All that needs to be done is to mix TraversableLike[A, MyColl[A]] into MyColl, and we are golden. There may be other ways to do this, but this one works and is simple.
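Why does this fix the static type? TraversableLike takes the concrete collection type (Repr) as a type parameter and declares operations such as filter and drop to return Repr. Below is a miniature of that pattern, assuming nothing from the collections library: CollLike, Box, fromSeq and sayHi are hypothetical names, and the real TraversableLike builds its results through newBuilder rather than a fromSeq method.

```scala
// Hypothetical miniature of the "Repr" pattern used by TraversableLike.
// Repr is the concrete type that operations such as filter return.
trait CollLike[A, +Repr] {
  def elems: Seq[A]
  // Each concrete collection says how to wrap a plain Seq into its own type.
  protected def fromSeq(xs: Seq[A]): Repr
  def filter(p: A => Boolean): Repr = fromSeq(elems.filter(p))
  def drop(n: Int): Repr = fromSeq(elems.drop(n))
}

// Fixing Repr to Box[A] makes filter and drop statically return Box[A].
class Box[A](val elems: Seq[A]) extends CollLike[A, Box[A]] {
  protected def fromSeq(xs: Seq[A]): Box[A] = new Box(xs)
  def sayHi: String = "hi!"
}

val small: Box[Int] = new Box(Seq(1, 2, 3)).filter(_ < 2)
println(small.sayHi) // compiles only because the static type is Box[Int]
```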

Note that the order in which the traits are mixed in is important: TraversableLike[A, MyColl[A]] must be mixed in after Traversable[A]. We want methods like map and drop to return instances of MyColl statically as well as dynamically. If the order were reversed, those methods would be typed as returning Traversable even though, at runtime, the actual instances would still be MyColl.
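The general principle that mixin order matters comes from Scala's class linearization: when more than one mixed-in trait defines the same member, the definition from the trait listed last takes precedence. The traits below are hypothetical and unrelated to the collections library; they only show that rule in isolation.

```scala
// Hypothetical traits that each override the same member.
trait Describes { def describe: String = "base" }
trait First extends Describes { override def describe: String = "first" }
trait Second extends Describes { override def describe: String = "second" }

// The same two traits, mixed in opposite orders.
class FirstThenSecond extends First with Second
class SecondThenFirst extends Second with First

// The trait mixed in last (rightmost) wins in each case.
val a = new FirstThenSecond().describe
val b = new SecondThenFirst().describe
```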
  import scala.collection._
  import scala.collection.generic._
  class MyColl[A](seq : A*) extends Traversable[A]
                               with GenericTraversableTemplate[A, MyColl]
                               with TraversableLike[A, MyColl[A]] {
    override def companion = MyColl
    def foreach[U](f: A => U) = util.Random.shuffle(seq.toSeq).foreach(f)
  }
  object MyColl extends TraversableFactory[MyColl] {
    implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MyColl[A]] = new GenericCanBuildFrom[A]
    def newBuilder[A] = new scala.collection.mutable.LazyBuilder[A,MyColl[A]] {
      def result = {
        val data = parts.foldLeft(List[A]()){(l,n) => l ++ n}
        new MyColl(data:_*)
      }
    }
  }

Now add a new method to demonstrate that the new collection works as desired, and we are done.

The following is the complete implementation with the tests. You can put it in a file and run scala <filename>, or paste it into the REPL.
  import scala.collection._
  import scala.collection.generic._
  import scala.collection.mutable.ListBuffer
  class MyColl[A](seq : A*) extends Traversable[A]
                               with GenericTraversableTemplate[A, MyColl]
                               with TraversableLike[A, MyColl[A]] {
    override def companion = MyColl
    def foreach[U](f: A => U) = util.Random.shuffle(seq.toSeq).foreach(f)
    def sayhi = println("hi!")
  }
  object MyColl extends TraversableFactory[MyColl] {
    implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MyColl[A]] = new GenericCanBuildFrom[A]
    def newBuilder[A] = new ListBuffer[A] mapResult (x => new MyColl(x:_*))
  }
  val c = new MyColl(1, 2, 3)
  println(c mkString ",")
  println(c mkString ",")
  assert(c.drop(1).isInstanceOf[MyColl[_]])
  assert((c map {_ + 1}).isInstanceOf[MyColl[_]])
  val o : MyColl[Int] = c filter {_ < 2}
  println(o mkString ",")
  o.sayhi
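The final version replaces the LazyBuilder subclass with new ListBuffer[A] mapResult (...): the ListBuffer collects the elements and mapResult wraps the finished List in a new collection. The sketch below exercises that builder on its own; Bag and bagBuilder are hypothetical names standing in for MyColl and its newBuilder.

```scala
import scala.collection.mutable.{Builder, ListBuffer}

// Hypothetical stand-in for MyColl: just wraps the elements it is given.
class Bag[A](val items: Seq[A])

// Same shape as MyColl.newBuilder: accumulate in a ListBuffer,
// then convert the finished List into a Bag.
def bagBuilder[A]: Builder[A, Bag[A]] =
  (new ListBuffer[A]).mapResult(xs => new Bag(xs))

val builder = bagBuilder[Int]
builder += 1
builder ++= List(2, 3)
val bag: Bag[Int] = builder.result()
```

Compared with subclassing LazyBuilder, this avoids writing result() by hand: the ListBuffer already knows how to produce a List, and mapResult only adds the final conversion.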

2 comments:

  1. I thought you were answering http://stackoverflow.com/questions/2534893/how-do-i-implement-a-collection-in-scala-2-8, but I didn't see this there. Do you plan to?
