Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed Dec 29, 2023
2 parents 81919db + f831a1f commit ab63b5b
Show file tree
Hide file tree
Showing 51 changed files with 924 additions and 604 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ checkstyle {
}

dependencies {
implementation("ee.carlrobert:llm-client:0.1.3")
implementation("ee.carlrobert:llm-client:0.2.0")
}

tasks {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ public String buildPromptWithContext(String prompt) {
}

public List<Item<Object, double[]>> createEmbeddings(
List<CheckedFile> checkedFiles,
List<ReferencedFile> referencedFiles,
@Nullable ProgressIndicator indicator) {
var words = new ArrayList<Item<Object, double[]>>();
for (int i = 0; i < checkedFiles.size(); i++) {
for (int i = 0; i < referencedFiles.size(); i++) {
try {
var checkedFile = checkedFiles.get(i);
addEmbeddings(checkedFile, words);
var referencedFile = referencedFiles.get(i);
addEmbeddings(referencedFile, words);

if (indicator != null) {
indicator.setFraction((double) i / checkedFiles.size());
indicator.setFraction((double) i / referencedFiles.size());
}
} catch (Throwable t) {
// ignore
Expand Down Expand Up @@ -101,24 +101,26 @@ private String getSearchQuery(String userPrompt) throws JsonProcessingException
.getContent();
}

private void addEmbeddings(CheckedFile checkedFile, List<Item<Object, double[]>> prevEmbeddings) {
var fileExtension = checkedFile.getFileExtension();
private void addEmbeddings(
ReferencedFile referencedFile,
List<Item<Object, double[]>> prevEmbeddings) {
var fileExtension = referencedFile.getFileExtension();
var codeSplitter = SplitterFactory.getCodeSplitter(fileExtension);
if (codeSplitter != null) {
var chunks = codeSplitter.split(
checkedFile.getFileName(),
checkedFile.getFileContent());
referencedFile.getFileName(),
referencedFile.getFileContent());
var embeddings = openAIClient.getEmbeddings(chunks);
for (int i = 0; i < chunks.size(); i++) {
prevEmbeddings.add(
new Word(chunks.get(i), checkedFile.getFileName(), normalize(embeddings.get(i))));
new Word(chunks.get(i), referencedFile.getFileName(), normalize(embeddings.get(i))));
}
} else {
var chunks = splitText(checkedFile.getFileContent(), 400);
var chunks = splitText(referencedFile.getFileContent(), 400);
var embeddings = getEmbeddings(chunks);
for (int i = 0; i < chunks.size(); i++) {
prevEmbeddings.add(
new Word(chunks.get(i), checkedFile.getFileName(), normalize(embeddings.get(i))));
new Word(chunks.get(i), referencedFile.getFileName(), normalize(embeddings.get(i))));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class CheckedFile {
public class ReferencedFile {

private final String fileName;
private final String filePath;
private final String fileContent;

public CheckedFile(File file) {
public ReferencedFile(File file) {
this.fileName = file.getName();
this.filePath = file.getPath();
try {
Expand All @@ -23,7 +24,7 @@ public CheckedFile(File file) {
}
}

public CheckedFile(String fileName, String filePath, String fileContent) {
public ReferencedFile(String fileName, String filePath, String fileContent) {
this.fileName = fileName;
this.filePath = filePath;
this.fileContent = fileContent;
Expand All @@ -50,4 +51,22 @@ public String getFileExtension() {
}
return "";
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

ReferencedFile that = (ReferencedFile) o;
return Objects.equals(filePath, that.filePath);
}

@Override
public int hashCode() {
return Objects.hash(filePath);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
You are an AI programming assistant.
Follow the user's requirements carefully & to the letter.
Your responses should be informative and logical.
You should always adhere to technical information.
If the user asks for code or technical questions, you must provide code suggestions and adhere to technical information.
If the question is related to a developer, you must respond with content related to a developer.
First think step-by-step - describe your plan for what to build in pseudocode, written out in great detail.
Then output the code in a single code block.
Minimize any other prose.
Keep your answers short and impersonal.
Use Markdown formatting in your answers.
Make sure to include the programming language name at the start of the Markdown code blocks.
Avoid wrapping the whole response in triple backticks.
The user works in an IDE built by JetBrains which has a concept for editors with open files, integrated unit test support, and output pane that shows the output of running the code as well as an integrated terminal.
You can only give one reply for each conversation turn.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
I will provide you with a snippet of code that is causing a compilation error.
Your task is to identify the potential causes of the compilation error(s) and propose code solutions to fix them.
Please approach this step by step, explaining your reasoning as you go.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Write a short and descriptive git commit message for the following git diff.
Use imperative mood, present tense, active voice and verbs.
Your entire response will be passed directly into git commit.
4 changes: 2 additions & 2 deletions src/main/java/ee/carlrobert/codegpt/CodeGPTKeys.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package ee.carlrobert.codegpt;

import com.intellij.openapi.util.Key;
import ee.carlrobert.embedding.CheckedFile;
import ee.carlrobert.embedding.ReferencedFile;
import java.util.List;

public class CodeGPTKeys {

public static final Key<List<CheckedFile>> SELECTED_FILES = Key.create("selectedFiles");
public static final Key<List<ReferencedFile>> SELECTED_FILES = Key.create("selectedFiles");
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package ee.carlrobert.codegpt;

import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

import com.intellij.compiler.CompilerMessageImpl;
import com.intellij.notification.NotificationAction;
import com.intellij.notification.NotificationType;
import com.intellij.openapi.compiler.CompilationStatusListener;
import com.intellij.openapi.compiler.CompileContext;
import com.intellij.openapi.compiler.CompilerMessage;
import com.intellij.openapi.compiler.CompilerMessageCategory;
import com.intellij.openapi.project.Project;
import ee.carlrobert.codegpt.completions.CompletionRequestProvider;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowContentManager;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.embedding.ReferencedFile;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import org.jetbrains.annotations.NotNull;

public class ProjectCompilationStatusListener implements CompilationStatusListener {

private final Project project;

public ProjectCompilationStatusListener(Project project) {
this.project = project;
}

@Override
public void compilationFinished(
boolean aborted,
int errors,
int warnings,
@NotNull CompileContext compileContext) {
var configuration = ConfigurationState.getInstance();
var success = !configuration.isCaptureCompileErrors()
|| (!aborted && errors == 0 && warnings == 0);
if (success) {
return;
}
if (errors > 0) {
OverlayUtil.getDefaultNotification(
CodeGPTBundle.get("notification.compilationError.description"),
NotificationType.INFORMATION)
.addAction(NotificationAction.createSimpleExpiring(
CodeGPTBundle.get("notification.compilationError.okLabel"),
() -> project.getService(StandardChatToolWindowContentManager.class)
.sendMessage(getMultiFileMessage(compileContext), FIX_COMPILE_ERRORS)))
.addAction(NotificationAction.createSimpleExpiring(
CodeGPTBundle.get("checkForUpdatesTask.notification.hideButton"),
() -> ConfigurationState.getInstance().setCaptureCompileErrors(false)))
.notify(project);
}
}

private Message getMultiFileMessage(CompileContext compileContext) {
var errorMapping = getErrorMapping(compileContext);
var prompt = errorMapping.values().stream()
.flatMap(Collection::stream)
.collect(joining("\n\n"));

var message = new Message("Fix the following compile errors:\n\n" + prompt);
message.setReferencedFilePaths(errorMapping.keySet().stream()
.map(ReferencedFile::getFilePath)
.collect(toList()));
message.setUserMessage(message.getPrompt());
message.setPrompt(CompletionRequestProvider.getPromptWithContext(
new ArrayList<>(errorMapping.keySet()),
prompt));
return message;
}

private HashMap<ReferencedFile, List<String>> getErrorMapping(CompileContext compileContext) {
var errorMapping = new HashMap<ReferencedFile, List<String>>();
for (var compilerMessage : compileContext.getMessages(CompilerMessageCategory.ERROR)) {
var key = new ReferencedFile(new File(compilerMessage.getVirtualFile().getPath()));
var prevValue = errorMapping.get(key);
if (prevValue == null) {
prevValue = new ArrayList<>();
}
prevValue.add(getCompilerErrorDetails(compilerMessage));
errorMapping.put(key, prevValue);
}
return errorMapping;
}

private String getCompilerErrorDetails(CompilerMessage compilerMessage) {
if (compilerMessage instanceof CompilerMessageImpl) {
return format(
"%s:%d:%d - `%s`",
compilerMessage.getVirtualFile().getName(),
((CompilerMessageImpl) compilerMessage).getLine(),
((CompilerMessageImpl) compilerMessage).getColumn(),
compilerMessage.getMessage());
}
return format(
"%s - `%s`",
compilerMessage.getVirtualFile().getName(),
compilerMessage.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,13 @@
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.CompletionClientProvider;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.io.BufferedReader;
import java.io.File;
Expand All @@ -59,7 +55,7 @@ public GenerateGitCommitMessageAction() {
public void update(@NotNull AnActionEvent event) {
var selectedService = SettingsState.getInstance().getSelectedService();
if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE) {
var filesSelected = !getCheckedFilePaths(event).isEmpty();
var filesSelected = !getReferencedFilePaths(event).isEmpty();
var callAllowed = (selectedService == ServiceType.OPENAI
&& OpenAICredentialsManager.getInstance().isApiKeySet())
|| (selectedService == ServiceType.AZURE
Expand All @@ -82,7 +78,7 @@ public void actionPerformed(@NotNull AnActionEvent event) {
return;
}

var gitDiff = getGitDiff(project, getCheckedFilePaths(event));
var gitDiff = getGitDiff(project, getReferencedFilePaths(event));
var tokenCount = encodingManager.countTokens(gitDiff);
if (tokenCount > 4096 && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
return;
Expand All @@ -91,25 +87,8 @@ public void actionPerformed(@NotNull AnActionEvent event) {
var editor = getCommitMessageEditor(event);
if (editor != null) {
((EditorEx) editor).setCaretVisible(false);
generateMessage(project, editor, gitDiff);
}
}

private void generateMessage(Project project, Editor editor, String gitDiff) {
var request = new OpenAIChatCompletionRequest.Builder(List.of(
new OpenAIChatCompletionMessage("system",
ConfigurationState.getInstance().getCommitMessagePrompt()),
new OpenAIChatCompletionMessage("user", gitDiff)))
.setModel(OpenAISettingsState.getInstance().getModel())
.build();
var selectedService = SettingsState.getInstance().getSelectedService();
if (selectedService == ServiceType.OPENAI) {
CompletionClientProvider.getOpenAIClient()
.getChatCompletion(request, getEventListener(project, editor.getDocument()));
}
if (selectedService == ServiceType.AZURE) {
CompletionClientProvider.getAzureClient()
.getChatCompletion(request, getEventListener(project, editor.getDocument()));
CompletionRequestService.getInstance()
.generateCommitMessageAsync(gitDiff, getEventListener(project, editor.getDocument()));
}
}

Expand Down Expand Up @@ -166,7 +145,7 @@ private Process createGitDiffProcess(String projectPath, List<String> filePaths)
}
}

private @NotNull List<String> getCheckedFilePaths(AnActionEvent event) {
private @NotNull List<String> getReferencedFilePaths(AnActionEvent event) {
var changesBrowserBase = event.getData(ChangesBrowserBase.DATA_KEY);
if (changesBrowserBase == null) {
return List.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import ee.carlrobert.codegpt.ui.checkbox.PsiElementCheckboxTree;
import ee.carlrobert.codegpt.ui.checkbox.VirtualFileCheckboxTree;
import ee.carlrobert.codegpt.util.file.FileUtil;
import ee.carlrobert.embedding.CheckedFile;
import ee.carlrobert.embedding.ReferencedFile;
import java.awt.Dimension;
import java.io.IOException;
import java.nio.file.Files;
Expand Down Expand Up @@ -62,7 +62,7 @@ public void actionPerformed(@NotNull AnActionEvent e) {
throw new RuntimeException("Could not obtain file tree");
}

var totalTokensLabel = new TotalTokensLabel(checkboxTree.getCheckedFiles());
var totalTokensLabel = new TotalTokensLabel(checkboxTree.getReferencedFiles());
checkboxTree.addCheckboxTreeListener(new CheckboxTreeListener() {
@Override
public void nodeStateChanged(@NotNull CheckedTreeNode node) {
Expand All @@ -81,10 +81,10 @@ public void nodeStateChanged(@NotNull CheckedTreeNode node) {
totalTokensLabel,
checkboxTree);
if (show == OK_EXIT_CODE) {
project.putUserData(CodeGPTKeys.SELECTED_FILES, checkboxTree.getCheckedFiles());
project.putUserData(CodeGPTKeys.SELECTED_FILES, checkboxTree.getReferencedFiles());
project.getMessageBus()
.syncPublisher(IncludeFilesInContextNotifier.FILES_INCLUDED_IN_CONTEXT_TOPIC)
.filesIncluded(checkboxTree.getCheckedFiles());
.filesIncluded(checkboxTree.getReferencedFiles());
includedFilesSettings.setPromptTemplate(promptTemplateTextArea.getText());
includedFilesSettings.setRepeatableContext(repeatableContextTextArea.getText());
}
Expand All @@ -111,9 +111,9 @@ private static class TotalTokensLabel extends JBLabel {
private int fileCount;
private int totalTokens;

TotalTokensLabel(List<CheckedFile> checkedFiles) {
fileCount = checkedFiles.size();
totalTokens = calculateTotalTokens(checkedFiles);
TotalTokensLabel(List<ReferencedFile> referencedFiles) {
fileCount = referencedFiles.size();
totalTokens = calculateTotalTokens(referencedFiles);
updateText();
}

Expand Down Expand Up @@ -167,8 +167,8 @@ private void updateText() {
FileUtil.convertLongValue(totalTokens)));
}

private int calculateTotalTokens(List<CheckedFile> checkedFiles) {
return checkedFiles.stream()
private int calculateTotalTokens(List<ReferencedFile> referencedFiles) {
return referencedFiles.stream()
.mapToInt(file -> encodingManager.countTokens(file.getFileContent()))
.sum();
}
Expand Down
Loading

0 comments on commit ab63b5b

Please sign in to comment.