Skip to content

Commit

Permalink
Feat(comparison): add model comparison from ormb metadata (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
FogDong authored Sep 15, 2020
1 parent 2a75e8a commit 5ceb2ed
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 2 deletions.
53 changes: 53 additions & 0 deletions pkg/registry/apis/v1alpha1/descriptors/comparison.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package descriptors

import (
"context"

"github.com/caicloud/nirvana/definition"

"github.com/kleveross/klever-model-registry/pkg/registry/comparison"
"github.com/kleveross/klever-model-registry/pkg/registry/paging"
)

func init() {
register(comparisonAPI)
}

var comparisonAPI = definition.Descriptor{
Description: "APIs for model comparison",
Children: []definition.Descriptor{
{
Path: "/comparisons",
Definitions: []definition.Definition{generateComparison},
},
{
Path: "/comparativedocument",
Definitions: []definition.Definition{downloadComparison},
},
},
}

var generateComparison = definition.Definition{
Method: definition.List,
Description: "Generate Comparison",
Parameters: []definition.Parameter{
definition.BodyParameterFor("Comparison Body"),
paging.PageDefinitionParameter(),
},
Results: definition.DataErrorResults("generate comparison"),
Function: func(ctx context.Context, models comparison.Comparison, opt *paging.ListOption) (*comparison.ORMBModelList, error) {
return comparison.Generator(ctx, models, opt)
},
}

var downloadComparison = definition.Definition{
Method: definition.Get,
Description: "Download Comparison",
Results: []definition.Result{definition.ErrorResult()},
Parameters: []definition.Parameter{
definition.BodyParameterFor("Comparison Body"),
},
Function: func(ctx context.Context, models comparison.Comparison) error {
return comparison.DownloadCSVFile(ctx, models)
},
}
193 changes: 193 additions & 0 deletions pkg/registry/comparison/comparison.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package comparison

import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"time"

"github.com/caicloud/nirvana/log"
ormbmodel "github.com/kleveross/ormb/pkg/model"

"github.com/kleveross/klever-model-registry/pkg/common"
"github.com/kleveross/klever-model-registry/pkg/registry/errors"
"github.com/kleveross/klever-model-registry/pkg/registry/harbor"
"github.com/kleveross/klever-model-registry/pkg/registry/paging"
"github.com/kleveross/klever-model-registry/pkg/util"
)

// Generator list models' metadata and compare
func Generator(ctx context.Context, models Comparison, opt *paging.ListOption) (*ORMBModelList, error) {
metaList, err := composeComparison(models.Models)
if err != nil {
return nil, err
}
return toORMBModelList(metaList, opt), nil
}

// toModelJobList is convert to ModelJobList struct.
func toORMBModelList(items []*ormbmodel.Model, opt *paging.ListOption) *ORMBModelList {
datas := paging.Page(items, opt)
modelList := &ORMBModelList{
ListMeta: paging.ListMeta{
TotalItems: datas.TotalItems,
},
Items: []*ormbmodel.Model{},
}

for _, d := range datas.Items {
modelList.Items = append(modelList.Items, d.(*ormbmodel.Model))
}
return modelList
}

func composeComparison(models []ComparisonModel) ([]*ormbmodel.Model, error) {
proxy := harbor.NewProxy(common.ORMBDomain, common.ORMBUserName, common.ORMBPassword)
metaList := make([]*ormbmodel.Model, 0)
for _, model := range models {
artifacts, err := proxy.ListArtifacts(model.Project, model.Name)
if err != nil {
return nil, err
}
for _, artifact := range artifacts {
for _, tag := range artifact.Tags {
if tag.Name == model.Tag {
manifest, err := json.Marshal(artifact.ExtraAttrs)
if err != nil {
return nil, err
}
var meta ormbmodel.Metadata
if err := json.Unmarshal(manifest, &meta); err != nil {
return nil, err
}
metaList = append(metaList, &ormbmodel.Model{
Path: fmt.Sprintf("%s/%s:%s", model.Project, model.Name, model.Tag),
Metadata: &meta,
})
break
}
}
}
}

return metaList, nil
}

// DownloadCSVFile downloads the comparison csv file
func DownloadCSVFile(ctx context.Context, models Comparison) error {
metaList, err := composeComparison(models.Models)
if err != nil {
return err
}
err = createCSVFile(ctx, metaList)
if err != nil {
return err
}

return nil
}

func createCSVFile(ctx context.Context, metas []*ormbmodel.Model) error {
fileName := fmt.Sprintf("%d.csv", time.Now().Unix())

file, err := os.Create(fileName)
if err != nil {
return fmt.Errorf("open file is failed, err: %s", err.Error())
}

defer func() {
if err := file.Close(); err != nil {
log.Error("close file is failed, err: %s", err.Error())
}
if err := os.Remove(fileName); err != nil {
log.Error("remove file is failed, err: %s", err.Error())
}
}()

// Write UTF-8 BOM mainly for unidentifiable Chinese code,
// see https://stackoverflow.com/questions/2223882/whats-the-difference-between-utf-8-and-utf-8-without-bom
_, err = file.WriteString("\xEF\xBB\xBF")
if err != nil {
return fmt.Errorf("write utf-8 bom err: %s", err.Error())
}
csvWrite := csv.NewWriter(file)

content, err := composeCSVFileContent(metas)
if err != nil {
return err
}
err = csvWrite.WriteAll(content)
if err != nil {
return err
}
csvWrite.Flush()

if _, err := file.Seek(0, io.SeekStart); err != nil {
log.Errorf("Seek file is failed, err: %s", err.Error())
return err
}

responseWriter := util.GetResponseFromContext(ctx)
responseWriter.Header().Set("Content-Type", "application/octet-stream")
responseWriter.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", fileName))
_, err = io.Copy(responseWriter, file)
if err != nil {
return errors.RenderInternalServerError(err)
}
return nil
}

func composeCSVFileContent(metas []*ormbmodel.Model) ([][]string, error) {
fileContents := [][]string{
{"Basic Info"},
{"Model Name"},
{"Model Source"},
{"Model Framework"},
{"Model Format"},
{"Model Inputs"},
{"Model Outputs"},
}
for index, content := range fileContents {
for _, meta := range metas {
switch content[0] {
case "Basic Info":
continue
case "Model Name":
content = append(content, meta.Path)
case "Model Source":
content = append(content, meta.Metadata.Author)
case "Model Framework":
content = append(content, meta.Metadata.Framework)
case "Model Format":
content = append(content, meta.Metadata.Format)
case "Model Inputs":
jsonString := "-"
if meta.Metadata.Signature.Inputs != nil {
jsonString = composeJSONString(meta.Metadata.Signature.Inputs)
}
content = append(content, jsonString)
case "Model Outputs":
jsonString := "-"
if meta.Metadata.Signature.Outputs != nil {
jsonString = composeJSONString(meta.Metadata.Signature.Outputs)
}
content = append(content, jsonString)
}
fileContents[index] = content
}
}

return fileContents, nil
}

func composeJSONString(obj interface{}) string {
bytes, err := json.MarshalIndent(obj, "", "\t")
if err != nil {
log.Errorf("Compose JSON string failed: %v", err.Error())
return ""
}
return string(bytes)
}
22 changes: 22 additions & 0 deletions pkg/registry/comparison/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package comparison

import (
ormbmodel "github.com/kleveross/ormb/pkg/model"

"github.com/kleveross/klever-model-registry/pkg/registry/paging"
)

type Comparison struct {
Models []ComparisonModel `json:"models"`
}

type ComparisonModel struct {
Name string `json:"name"`
Project string `json:"project"`
Tag string `json:"tag"`
}

type ORMBModelList struct {
ListMeta paging.ListMeta `json:"metadata"`
Items []*ormbmodel.Model `json:"items"`
}
4 changes: 2 additions & 2 deletions pkg/registry/harbor/artifacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (p *Proxy) createModelJob(path string, byteManifests []byte) error {
modelName := pathSlice[2]
versionName := pathSlice[4]

artis, err := p.listArtifacts(projectName, modelName)
artis, err := p.ListArtifacts(projectName, modelName)
if err != nil {
return err
}
Expand Down Expand Up @@ -85,7 +85,7 @@ func (p *Proxy) createModelJob(path string, byteManifests []byte) error {
return nil
}

func (p *Proxy) listArtifacts(project, repo string) ([]Artifact, error) {
func (p *Proxy) ListArtifacts(project, repo string) ([]Artifact, error) {
url := fmt.Sprintf("http://%v/api/v2.0/projects/%v/repositories/%v/artifacts",
p.Domain, project, repo)

Expand Down

0 comments on commit 5ceb2ed

Please sign in to comment.