Skip to content

Commit

Permalink
rename test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoney committed Jun 28, 2024
1 parent 5b4a7e8 commit 755f165
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions blis/level3/dgemm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package level3

import (
"testing"

"github.com/jmoney/go-lapack/mat"
)

type TestCase struct {
A, B, C, C_Hat mat.General
A_Transpose mat.TransT
B_Transpose mat.TransT
alpha float64
beta float64
assertError bool
}

func Test_Dgemm(t *testing.T) {
tests := map[string]TestCase{
"A*B=C": {
A: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{1.0, 2.0, 3.0, 4.0}},
B: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{5.0, 6.0, 7.0, 8.0}},
C: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{0.0, 0.0, 0.0, 0.0}},
C_Hat: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{5.0, 7.0, 25.0, 35.0}},
A_Transpose: mat.NoTranspose,
B_Transpose: mat.NoTranspose,
alpha: 1.0,
beta: 0.0,
assertError: false,
},
"5A*B=C": {
A: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{1.0, 2.0, 3.0, 4.0}},
B: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{5.0, 6.0, 7.0, 8.0}},
C: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{0.0, 0.0, 0.0, 0.0}},
C_Hat: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{25.0, 35.0, 125.0, 175.0}},
A_Transpose: mat.NoTranspose,
B_Transpose: mat.NoTranspose,
alpha: 5.0,
beta: 0.0,
assertError: false,
},
"A*B=C Dimensions do not match": {
A: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{1.0, 2.0, 3.0, 4.0}},
B: mat.General{Rows: 3, Cols: 1, RowsStride: 2, ColsStride: 1, Data: []float64{5.0, 6.0, 7.0}},
C: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{0.0, 0.0, 0.0, 0.0}},
C_Hat: mat.General{Rows: 2, Cols: 2, RowsStride: 2, ColsStride: 1, Data: []float64{0.0, 0.0, 0.0, 0.0}},
A_Transpose: mat.NoTranspose,
B_Transpose: mat.NoTranspose,
alpha: 1.0,
beta: 0.0,
assertError: true,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := Bli_dgemm(test.A_Transpose, &test.alpha, test.A, test.B_Transpose, &test.beta, test.B, test.C)

if test.assertError && err == nil {
t.Errorf("%s failed: %v", t.Name(), err)
}

if !test.assertError {
for i := 0; i < len(test.C.Data); i++ {
if test.C.Data[i] != test.C_Hat.Data[i] {
t.Errorf("%s failed: %f != %f", t.Name(), test.C.Data[i], test.C_Hat.Data[i])
}
}
}
})
}
}

0 comments on commit 755f165

Please sign in to comment.