diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 52779d2f..515ec380 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -170,6 +170,16 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider } } + /** + * Compares if two [[DataFrame]]s are equal without caring about order of rows, by + * finding elements in one DataFrame not in the other. The resulting DataFrame + * should be empty inferring the two DataFrames have the same elements. + */ + def assertDataFrameNoOrderEquals(expected: DataFrame, result: DataFrame) { + assertEmpty(expected.except(result).rdd.take(maxUnequalRowsToShow)) + assertEmpty(result.except(expected).rdd.take(maxUnequalRowsToShow)) + } + /** * Zip RDD's with precise indexes. This is used so we can join two DataFrame's * Rows together regardless of if the source is different but still compare diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 0f92af6d..5e7160a2 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -37,6 +37,13 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { assertDataFrameEquals(input, input) } + test("dataframe should be equal with different order of rows") { + import sqlContext.implicits._ + val input = sc.parallelize(inputList).toDF + val reverseInput = sc.parallelize(inputList.reverse).toDF + assertDataFrameNoOrderEquals(input, reverseInput) + } + test("unequal dataframes should not be equal") { import sqlContext.implicits._ val input = sc.parallelize(inputList).toDF @@ -46,6 +53,18 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { } } + test("unequal dataframe with different order should not equal") { + import sqlContext.implicits._ + val input = sc.parallelize(inputList).toDF + val input2 = sc.parallelize(List(inputList.head)).toDF + intercept[org.scalatest.exceptions.TestFailedException] { + assertDataFrameNoOrderEquals(input, input2) + } + intercept[org.scalatest.exceptions.TestFailedException] { + assertDataFrameNoOrderEquals(input2, input) + } + } + test("dataframe approx expected") { import sqlContext.implicits._ val input = sc.parallelize(inputList).toDF