Skip to content

Commit

Permalink
cleanup flag setting. Add tests. Add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoney committed Jun 28, 2024
1 parent 8795bb6 commit 5b4a7e8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .env
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
CGO_CFLAGS=-I/opt/homebrew/Cellar/blis/1.0/include
CGO_LDFLAGS=-L/opt/homebrew/Cellar/blis/1.0/lib
CGO_CFLAGS="-I/opt/homebrew/Cellar/blis/1.0/include"
CGO_LDFLAGS="-L/opt/homebrew/Cellar/blis/1.0/lib"
3 changes: 1 addition & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
{
"go.buildFlags": ["-ldflags", "-extldflags -w"],
"go.testFlags": ["-v", "-count=1", "-ldflags", "-extldflags -w"],
"go.testEnvFile": ".env",
"go.lintOnSave": "file",
"go.testEnvVars": {
"CGO_CFLAGS":"-I/opt/homebrew/Cellar/blis/1.0/include",
"CGO_LDFLAGS":"-L/opt/homebrew/Cellar/blis/1.0/lib",
},
"go.toolsEnvVars": {
"CGO_CFLAGS": "-I/opt/homebrew/Cellar/blis/1.0/include",
"CGO_CFLAGS":"-I/opt/homebrew/Cellar/blis/1.0/include",
"CGO_LDFLAGS":"-L/opt/homebrew/Cellar/blis/1.0/lib",
},
}
37 changes: 34 additions & 3 deletions blis/level3/level3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,47 @@ package level3
// #include "blis.h"
import "C"
import (
"github.com/jmoney/go-blas/mat"
"fmt"

"github.com/jmoney/go-lapack/mat"
)

// Bli_dgemm maps to bli_dgemm in BLIS.
//
// Computes the operation: C := (alpha * A') * (beta * B') + C
func Bli_dgemm(transa mat.TransT, transb mat.TransT, alpha *float64, a mat.General, beta *float64, b mat.General, c mat.General) {
// Computes the double precision operation: C := (alpha * A') * (beta * B') + C
func Bli_dgemm(transa mat.TransT, alpha *float64, a mat.General, transb mat.TransT, beta *float64, b mat.General, c mat.General) error {

if err := validateDimensions(transa, a, transb, b, c); err != nil {
return err
}

C.bli_dgemm(C.trans_t(transa), C.trans_t(transb),
C.longlong(a.Rows), C.longlong(b.Cols), C.longlong(a.Cols),
(*C.double)(alpha), (*C.double)(&a.Data[0]), C.longlong(a.RowsStride), C.longlong(a.ColsStride),
(*C.double)(&b.Data[0]), C.longlong(b.RowsStride), C.longlong(b.ColsStride), (*C.double)(beta),
(*C.double)(&c.Data[0]), C.longlong(c.RowsStride), C.longlong(c.ColsStride))

return nil
}

func validateDimensions(transa mat.TransT, a mat.General, transb mat.TransT, b mat.General, c mat.General) error {
if transa == mat.NoTranspose && transb == mat.NoTranspose {
if a.Cols != b.Rows || a.Rows != c.Rows || b.Cols != c.Cols {
return fmt.Errorf("dimensions do not match")
}
} else if transa == mat.Transpose && transb == mat.NoTranspose {
if a.Rows != b.Rows || a.Cols != c.Rows || b.Cols != c.Cols {
return fmt.Errorf("dimensions do not match")
}
} else if transa == mat.NoTranspose && transb == mat.Transpose {
if a.Cols != b.Cols || a.Rows != c.Rows || b.Rows != c.Cols {
return fmt.Errorf("dimensions do not match")
}
} else if transa == mat.Transpose && transb == mat.Transpose {
if a.Rows != b.Cols || a.Cols != c.Rows || b.Rows != c.Cols {
return fmt.Errorf("dimensions do not match")
}
}

return nil
}
24 changes: 0 additions & 24 deletions blis/level3/level3_test.go

This file was deleted.

2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/jmoney/go-blas
module github.com/jmoney/go-lapack

go 1.22.4
6 changes: 3 additions & 3 deletions mat/mat.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package mat

type TransT int
type TransT uint32

const (
Transpose TransT = iota
Expand All @@ -11,7 +11,7 @@ type DimT int64
type IncT int64

type General struct {
Rows, Cols int
RowsStride, ColsStride int
Rows, Cols DimT
RowsStride, ColsStride IncT
Data []float64
}

0 comments on commit 5b4a7e8

Please sign in to comment.