diff --git a/ui/apiserver/apiserver.go b/ui/apiserver/apiserver.go index 2ea8d06..6de1357 100644 --- a/ui/apiserver/apiserver.go +++ b/ui/apiserver/apiserver.go @@ -336,7 +336,7 @@ func (api *ApiServer) runIlabChatCommand(question, context string) (string, erro cmd = exec.Command("echo", cmdArgs...) api.logger.Infof("Running in test mode: %s", commandStr) } else { - modelName, err := api.fetchModelName(true) + modelName, err := api.fetchModelName(true, api.preCheckEndpointURL) if err != nil { api.logger.Errorf("Failed to fetch model name: %v", err) return "failed to retrieve the model name", err @@ -382,9 +382,8 @@ func setupLogger(debugMode bool) *zap.SugaredLogger { // fetchModelName hits the defined precheck endpoint with "/models" appended to extract the model name. // If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens. -func (api *ApiServer) fetchModelName(fullName bool) (string, error) { +func (api *ApiServer) fetchModelName(fullName bool, endpoint string) (string, error) { // Ensure the endpoint URL ends with "/models" - endpoint := api.preCheckEndpointURL if !strings.HasSuffix(endpoint, "/") { endpoint += "/" } diff --git a/worker/cmd/generate.go b/worker/cmd/generate.go index 7e4647b..297619e 100644 --- a/worker/cmd/generate.go +++ b/worker/cmd/generate.go @@ -35,23 +35,24 @@ import ( ) var ( - WorkDir string - VenvDir string - PreCheckEndpointURL string - SdgEndpointURL string - NumInstructions int - GitRemote string - Origin string - GithubUsername string - GithubToken string - S3Bucket string - AWSRegion string - TlsClientCertPath string - TlsClientKeyPath string - TlsServerCaCertPath string - TlsInsecure bool - MaxSeed int - TaxonomyFolders = []string{"compositional_skills", "knowledge"} + WorkDir string + VenvDir string + PreCheckEndpointURL string + PreCheckScoringEndpointURL string + SdgEndpointURL string + NumInstructions int + GitRemote string + Origin string + GithubUsername string + GithubToken string + S3Bucket string + AWSRegion string + TlsClientCertPath string + TlsClientKeyPath string + TlsServerCaCertPath string + TlsInsecure bool + MaxSeed int + TaxonomyFolders = []string{"compositional_skills", "knowledge"} ) const ( @@ -76,35 +77,37 @@ const ( // Worker encapsulates dependencies and methods to process jobs type Worker struct { - ctx context.Context - pool *redis.Pool - svc *s3.Client - logger *zap.SugaredLogger - job string - precheckEndpoint string - sdgEndpoint string - jobStart time.Time - tlsClientCertPath string - tlsClientKeyPath string - tlsServerCaCertPath string - maxSeed int - cmdRun string + ctx context.Context + pool *redis.Pool + svc *s3.Client + logger *zap.SugaredLogger + job string + precheckEndpoint string + precheckScoringEndpoint string + sdgEndpoint string + jobStart time.Time + tlsClientCertPath string + tlsClientKeyPath string + tlsServerCaCertPath string + maxSeed int + cmdRun string } -func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { +func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, precheckScoringEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { return &Worker{ - ctx: ctx, - pool: pool, - svc: svc, - logger: logger, - job: job, - precheckEndpoint: precheckEndpoint, - sdgEndpoint: sdgEndpoint, - jobStart: time.Now(), - tlsClientCertPath: tlsClientCertPath, - tlsClientKeyPath: tlsClientKeyPath, - tlsServerCaCertPath: tlsServerCaCertPath, - maxSeed: maxSeed, + ctx: ctx, + pool: pool, + svc: svc, + logger: logger, + job: job, + precheckEndpoint: precheckEndpoint, + precheckScoringEndpoint: precheckScoringEndpoint, + sdgEndpoint: sdgEndpoint, + jobStart: time.Now(), + tlsClientCertPath: tlsClientCertPath, + tlsClientKeyPath: tlsClientKeyPath, + tlsServerCaCertPath: tlsServerCaCertPath, + maxSeed: maxSeed, } } @@ -118,6 +121,7 @@ func init() { generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in") generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory") generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.") + generateCmd.Flags().StringVarP(&PreCheckScoringEndpointURL, "precheck-scoring-endpoint-url", "", PreCheckEndpointURL, "Endpoint hosting the model API that will be scoring the output of precheck against the answers supplied in the PR. Default, it assumes the model is the same as precheck model and is served locally.") generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.") generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate") generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo") @@ -190,6 +194,7 @@ var generateCmd = &cobra.Command{ } NewJobProcessor(ctx, pool, svc, sugar, job, PreCheckEndpointURL, + PreCheckScoringEndpointURL, SdgEndpointURL, TlsClientCertPath, TlsClientKeyPath, @@ -211,12 +216,100 @@ var generateCmd = &cobra.Command{ }, } +func (w *Worker) runPrecheckScoring(precheckPRAnswers []string, precheckEndpointAnswers []string, precheckPRQuestions []string, lab string, outputDir string, preCheckScoringModelName string) error { + if len(precheckPRAnswers) != len(precheckEndpointAnswers) { + errMsg := "PR answers and Endpoint answers returned a different number of entries, something went wrong" + w.logger.Error(errMsg) + return fmt.Errorf(errMsg) + } + + workDir := "." + if WorkDir != "" { + workDir = WorkDir + } + combinedYAMLScoringPath := path.Join(outputDir, "combined_chatlog_scoring.yaml") + + type QuestionScore struct { + Question string + HumanAnswer string + EndpointAnswer string + Score string + } + + type QuestionScoreReport struct { + RunTime string + QuestionScores []QuestionScore + } + + yamlData := QuestionScoreReport{} + for i := 0; i < len(precheckPRAnswers); i++ { + err, promptTemplate := generatePrecheckScoringPrompt(precheckPRAnswers[i], precheckEndpointAnswers[i], precheckPRQuestions[i]) + if err != nil { + w.logger.Errorf("Failed to generate a prompt for precheck scorring: %v", err) + return err + } + + commandStr := fmt.Sprintf("chat --quick-question %s", promptTemplate) + if TlsInsecure { + commandStr += " --tls-insecure" + } + if PreCheckScoringEndpointURL != localEndpoint && preCheckScoringModelName != "unknown" { + commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckScoringEndpointURL, preCheckScoringModelName) + } + cmdArgs := strings.Fields(commandStr) + cmd := exec.Command(lab, cmdArgs...) + // Register the command for reporting/logging + w.cmdRun = cmd.String() + w.logger.Infof("Running the precheck scoring command: %s", cmd.String()) + + cmd.Dir = workDir + cmd.Env = os.Environ() + var out bytes.Buffer + var errOut bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &errOut + err = cmd.Run() + if err != nil { + w.logger.Errorf("Precheck scoring command failed with error: %v; stderr: %s", err, errOut.String()) + continue + } + + questionScore := QuestionScore{ + Question: precheckPRQuestions[i], + HumanAnswer: precheckPRAnswers[i], + EndpointAnswer: precheckEndpointAnswers[i], + Score: out.String(), + } + yamlData.QuestionScores = append(yamlData.QuestionScores, questionScore) + + } + + yamlData.RunTime = time.Now().Format("2006-01-02T15_04_05") + + scoringYaml, err := yaml.Marshal(yamlData) + if err != nil { + w.logger.Errorf("Could not marshal scoring data to YAML: %v", err) + return err + } + + err = os.WriteFile(combinedYAMLScoringPath, scoringYaml, 0644) + if err != nil { + w.logger.Errorf("Could not write chatlog to file: %v", err) + return err + } + + return nil +} + // runPrecheck runs lab chat against git diffed yaml files -func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { +func (w *Worker) runPrecheck(lab, outputDir, modelName string) (error, []string, []string, []string) { workDir := "." if WorkDir != "" { workDir = WorkDir } + precheckPRAnswers := []string{} + precheckEndpointAnswers := []string{} + precheckPRQuestions := []string{} chatlogDir := path.Join(workDir, "data", "chatlogs") combinedYAMLPath := path.Join(outputDir, "combined_chatlogs.yaml") combinedLogPath := path.Join(outputDir, "combined_chatlogs.log") @@ -297,19 +390,19 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { stdout, err := cmd.StdoutPipe() if err != nil { w.logger.Errorf("Could not get stdout pipe: %v", err) - return err + return err, []string{}, []string{}, []string{} } w.logger.Debug("Running ilab diff") if err := cmd.Start(); err != nil { w.logger.Errorf("Could not start command(%s %s): %v", cmd.Path, strings.Join(cmd.Args, " "), err) - return err + return err, []string{}, []string{}, []string{} } output, err := io.ReadAll(stdout) if err != nil { w.logger.Errorf("Could not read stdout: %v", err) - return err + return err, []string{}, []string{}, []string{} } outputStr := string(output) w.logger.Debugf("Output: %s", outputStr) @@ -327,7 +420,7 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { if yamlFileCount == 0 { errMsg := "No modified YAML files detected in the PR for precheck" w.logger.Error(errMsg) - return fmt.Errorf(errMsg) + return fmt.Errorf(errMsg), []string{}, []string{}, []string{} } // Proceed with YAML files processing if they exist @@ -340,14 +433,14 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { f, err := os.Open(filePath) if err != nil { w.logger.Errorf("Could not open taxonomy file: %v", err) - return err + return err, []string{}, []string{}, []string{} } defer f.Close() content, err := io.ReadAll(f) if err != nil { w.logger.Error(err) - return err + return err, []string{}, []string{}, []string{} } var data map[string]interface{} @@ -356,15 +449,16 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { // Odds are, the PR was not yaml-linted since it's invalid YAML failing unmarshalling err = fmt.Errorf("the original taxonomy YAML likely did not pass yaml-linting, here is the unmarshalling error: %v", err) w.logger.Error(err) - return err + return err, []string{}, []string{}, []string{} } // Check if "seed_examples" exists and is a list + seedExamples, ok := data["seed_examples"].([]interface{}) if !ok { err = fmt.Errorf("seed_examples not found or not a list") w.logger.Error(err) - return err + return err, []string{}, []string{}, []string{} } for _, item := range seedExamples { @@ -378,6 +472,12 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { w.logger.Error("Question not found or not a string") continue } + answer, ok := example["answer"].(string) + if !ok { + w.logger.Error("Question not found or not a string") + continue + } + precheckPRAnswers = append(precheckPRAnswers, answer) context, hasContext := example["context"].(string) originalQuestion := question @@ -418,6 +518,9 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { "output": out.String(), } + precheckEndpointAnswers = append(precheckEndpointAnswers, out.String()) + precheckPRQuestions = append(precheckPRQuestions, originalQuestion) + if hasContext { logData["input"].(map[string]string)["context"] = context } @@ -450,7 +553,8 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { time.Sleep(1 * time.Second) } } - return nil + return nil, precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions + // return nil, precheckPRAnswers, precheckPRQuestions } // processJob processes a given job, all jobs start here @@ -533,7 +637,7 @@ func (w *Worker) processJob() { // sdg-svc does not have a models endpoint as yet if jobType != jobSDG && PreCheckEndpointURL != localEndpoint { var err error - modelName, err = w.fetchModelName(true) + modelName, err = w.fetchModelName(true, w.precheckEndpoint) if err != nil { w.logger.Errorf("Failed to fetch model name: %v", err) modelName = "unknown" @@ -572,12 +676,32 @@ func (w *Worker) processJob() { case jobPreCheck: // @instructlab-bot precheck // Runs precheck on a backend node - err = w.runPrecheck(lab, outputDir, modelName) + err, precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions := w.runPrecheck(lab, outputDir, modelName) if err != nil { sugar.Errorf("Could not run precheck: %v", err) w.reportJobError(err) return } + + var scoringModelName string + + if jobType == jobPreCheck && w.precheckScoringEndpoint != localEndpoint { + var err error + scoringModelName, err = w.fetchModelName(true, w.precheckScoringEndpoint) + if err != nil { + w.logger.Errorf("Failed to fetch model name: %v", err) + scoringModelName = "unknown" + } + } else { + scoringModelName = w.getModelNameFromConfig() // will default to standard precheck model + } + + err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions, lab, outputDir, scoringModelName) + if err != nil { + sugar.Errorf("Could not run scoring on result of precheck: %v", err) + w.reportJobError(err) + return + } case jobSDG: // @instructlab-bot generate // Runs generate on the SDG backend @@ -864,9 +988,8 @@ func (w *Worker) getModelNameFromConfig() string { // fetchModelName hits the defined precheckEndpoint with "/models" appended to extract the model name. // If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens. -func (w *Worker) fetchModelName(fullName bool) (string, error) { +func (w *Worker) fetchModelName(fullName bool, endpoint string) (string, error) { // Ensure the endpoint URL ends with "/models" - endpoint := w.precheckEndpoint if !strings.HasSuffix(endpoint, "/") { endpoint += "/" } @@ -962,7 +1085,7 @@ func (w *Worker) determineModelName(jobType string) string { // precheck is the only case we use a remote OpenAI endpoint right now if PreCheckEndpointURL != localEndpoint && jobType == jobPreCheck { - modelName, err := w.fetchModelName(false) + modelName, err := w.fetchModelName(false, w.precheckEndpoint) if err != nil { w.logger.Errorf("Failed to fetch model name: %v", err) return "unknown" diff --git a/worker/cmd/generate_test.go b/worker/cmd/generate_test.go index 6102c18..53f6bb0 100644 --- a/worker/cmd/generate_test.go +++ b/worker/cmd/generate_test.go @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + mockServer.URL, "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem", @@ -160,12 +161,12 @@ func TestFetchModelName(t *testing.T) { 20, ) - modelName, err := w.fetchModelName(false) + modelName, err := w.fetchModelName(false, w.precheckEndpoint) assert.NoError(t, err, "fetchModelName should not return an error") expectedModelName := "Mixtral-8x7B-Instruct-v0.1" assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly") - modelName, err = w.fetchModelName(true) + modelName, err = w.fetchModelName(true, w.precheckEndpoint) assert.NoError(t, err, "fetchModelName should not return an error") expectedModelName = "/shared_model_storage/transformers_cache/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/5c79a376139be989ef1838f360bf4f1f256d7aec" assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly") @@ -214,13 +215,14 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + mockServer.URL, "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem", "dummy-ca-cert-path.pem", 20, ) - modelName, err := w.fetchModelName(false) + modelName, err := w.fetchModelName(false, w.precheckEndpoint) // Verify that an error was returned due to the invalid "object" field assert.Error(t, err, "fetchModelName should return an error for invalid object field") diff --git a/worker/cmd/templates.go b/worker/cmd/templates.go index 43b4946..a58b960 100644 --- a/worker/cmd/templates.go +++ b/worker/cmd/templates.go @@ -1,6 +1,7 @@ package cmd import ( + "bytes" "context" "encoding/json" "fmt" @@ -264,3 +265,48 @@ func generateFormattedYAML(ctx context.Context, outputDir, filename string, svc return s3Key } + +func generatePrecheckScoringPrompt(precheckPRAnswer string, precheckEndpointAnswer string, precheckQuestion string) (error, string) { + promptTemplate := ` + Evaluate and compare the quality of the below ### Model answer compared to the ### Human answer when given the same ### Question provided below. + The ### Human answer is to be treated as the ground truth answer. + Assign a score using the following 3 point scale: + 1: It means that the answers are identical or nearly identical, based on both the content of the two provided answers as + well as the wording and details of the answer provided. + + 2: It means that there is moderate variation in the answers. The two provided answers could have a moderately different sentence structure + and wording, or have some differences in the content or perspective, but still share some key points. + + 3: It means the answers are significantly different. The two provided answers differ greatly in wording and perspective or have very different + or contridictory facts and content. + + ### Question: + "{{ .Question }}" + ### Human answer: + "{{ .HumanAnswer }}" + ### Model answer: + "{{ .ModelAnswer }}" + + ` + + tmpl, err := template.New("modelScoring").Parse(promptTemplate) + if err != nil { + return fmt.Errorf("error parsing modelScoring prompt template: %w", err), "" + } + + data := struct { + HumanAnswer string + ModelAnswer string + Question string + }{ + HumanAnswer: precheckPRAnswer, + ModelAnswer: precheckEndpointAnswer, + Question: precheckQuestion, + } + var buf bytes.Buffer + err = tmpl.Execute(&buf, data) + if err != nil { + return fmt.Errorf("error executing modelScoring prompt template: %w", err), "" + } + return nil, buf.String() +}