Skip to content

Commit

Permalink
add model name config (openpilot-hub#21)
Browse files Browse the repository at this point in the history
* add model name config

* comboBox model
  • Loading branch information
xiangtianyu authored Dec 26, 2023
1 parent adc8b85 commit 7e450c4
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 1 deletion.
51 changes: 51 additions & 0 deletions src/main/java/com/zhongan/devpilot/enums/OpenAIModelNameEnum.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.zhongan.devpilot.enums;

public enum OpenAIModelNameEnum {
GPT3_5_TURBO("gpt-3.5-turbo", "gpt-3.5-turbo"),
GPT3_5_TURBO_16K("gpt-3.5-turbo-16k", "gpt-3.5-turbo(16k)"),
GPT4("gpt-4", "gpt-4"),
GPT4_32K("gpt-4-32k", "gpt-4(32k)"),
CUSTOM("custom", "Custom Model");

private String name;

private String displayName;

OpenAIModelNameEnum(String name, String displayName) {
this.name = name;
this.displayName = displayName;
}

public static OpenAIModelNameEnum fromName(String name) {
if (name == null) {
return GPT3_5_TURBO;
}
for (OpenAIModelNameEnum type : OpenAIModelNameEnum.values()) {
if (type.getName().equals(name)) {
return type;
}
}
return GPT3_5_TURBO;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getDisplayName() {
return displayName;
}

public void setDisplayName(String displayName) {
this.displayName = displayName;
}

@Override
public String toString() {
return displayName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ public String chatCompletion(DevPilotChatCompletionRequest chatCompletionRequest
return "Chat completion failed: host is empty";
}

var modelName = CodeLlamaSettingsState.getInstance().getModelName();

if (StringUtils.isEmpty(modelName)) {
return "Chat completion failed: code llama model name is empty";
}

chatCompletionRequest.setModel(modelName);

okhttp3.Response response;

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ public String chatCompletion(DevPilotChatCompletionRequest chatCompletionRequest
return "Chat completion failed: api key is empty";
}

chatCompletionRequest.setModel("gpt-3.5-turbo");
var modelName = OpenAISettingsState.getInstance().getModelName();

if (StringUtils.isEmpty(modelName)) {
return "Chat completion failed: openai model name is empty";
}

chatCompletionRequest.setModel(modelName);

okhttp3.Response response;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.intellij.util.ui.UI;
import com.zhongan.devpilot.enums.ModelServiceEnum;
import com.zhongan.devpilot.enums.ModelTypeEnum;
import com.zhongan.devpilot.enums.OpenAIModelNameEnum;
import com.zhongan.devpilot.settings.state.AIGatewaySettingsState;
import com.zhongan.devpilot.settings.state.CodeLlamaSettingsState;
import com.zhongan.devpilot.settings.state.DevPilotLlmSettingsState;
Expand All @@ -29,6 +30,10 @@ public class DevPilotConfigForm {

private final JBTextField openAIKeyField;

private final ComboBox<OpenAIModelNameEnum> openAIModelNameComboBox;

private final JBTextField openAICustomModelNameField;

private final JPanel aiGatewayServicePanel;

private final JBTextField aiGatewayBaseHostField;
Expand All @@ -39,6 +44,8 @@ public class DevPilotConfigForm {

private final JBTextField codeLlamaBaseHostField;

private final JBTextField codeLlamaModelNameField;

private Integer index;

public DevPilotConfigForm() {
Expand All @@ -50,6 +57,16 @@ public DevPilotConfigForm() {
var openAISettings = OpenAISettingsState.getInstance();
openAIBaseHostField = new JBTextField(openAISettings.getModelHost(), 30);
openAIKeyField = new JBTextField(openAISettings.getPrivateKey(), 30);
openAICustomModelNameField = new JBTextField(openAISettings.getCustomModelName(), 15);
var modelNameEnum = OpenAIModelNameEnum.fromName(openAISettings.getModelName());
openAICustomModelNameField.setEnabled(modelNameEnum == OpenAIModelNameEnum.CUSTOM);
openAIModelNameComboBox = new ComboBox<>(OpenAIModelNameEnum.values());
openAIModelNameComboBox.setSelectedItem(modelNameEnum);
openAIModelNameComboBox.addItemListener(e -> {
var selected = (OpenAIModelNameEnum) e.getItem();
openAICustomModelNameField.setEnabled(selected == OpenAIModelNameEnum.CUSTOM);
});

openAIServicePanel = createOpenAIServicePanel();

var aiGatewaySettings = AIGatewaySettingsState.getInstance();
Expand All @@ -68,6 +85,7 @@ public DevPilotConfigForm() {

var codeLlamaSettings = CodeLlamaSettingsState.getInstance();
codeLlamaBaseHostField = new JBTextField(codeLlamaSettings.getModelHost(), 30);
codeLlamaModelNameField = new JBTextField(codeLlamaSettings.getModelName(), 30);
codeLlamaServicePanel = createCodeLlamaServicePanel();

panelShow(selectedEnum);
Expand Down Expand Up @@ -136,6 +154,12 @@ private JPanel createOpenAIServicePanel() {
.add(UI.PanelFactory.panel(openAIKeyField)
.withLabel(DevPilotMessageBundle.get("devpilot.settings.service.apiKeyLabel"))
.resizeX(false))
.add(UI.PanelFactory.panel(openAIModelNameComboBox)
.withLabel(DevPilotMessageBundle.get("devpilot.settings.service.modelNameLabel"))
.resizeX(false))
.add(UI.PanelFactory.panel(openAICustomModelNameField)
.withLabel(DevPilotMessageBundle.get("devpilot.settings.service.customModelNameLabel"))
.resizeX(false))
.createPanel();
panel.setBorder(JBUI.Borders.emptyLeft(16));
return panel;
Expand All @@ -159,6 +183,9 @@ private JPanel createCodeLlamaServicePanel() {
.add(UI.PanelFactory.panel(codeLlamaBaseHostField)
.withLabel(DevPilotMessageBundle.get("devpilot.settings.service.modelHostLabel"))
.resizeX(false))
.add(UI.PanelFactory.panel(codeLlamaModelNameField)
.withLabel(DevPilotMessageBundle.get("devpilot.settings.service.modelNameLabel"))
.resizeX(false))
.createPanel();
panel.setBorder(JBUI.Borders.emptyLeft(16));
return panel;
Expand Down Expand Up @@ -217,4 +244,15 @@ public ModelTypeEnum getAIGatewayModel() {
return (ModelTypeEnum) aiGatewayModelComboBox.getSelectedItem();
}

public OpenAIModelNameEnum getOpenAIModelName() {
return (OpenAIModelNameEnum) openAIModelNameComboBox.getSelectedItem();
}

public String getOpenAICustomModelName() {
return openAICustomModelNameField.getText();
}

public String getCodeLlamaModelName() {
return codeLlamaModelNameField.getText();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ public boolean isModified() {
|| !selectedModel.getName().equals(settings.getSelectedModel())
|| !selectedModelType.getName().equals(aiGatewaySettings.getSelectedModel())
|| !serviceForm.getOpenAIBaseHost().equals(openAISettings.getModelHost())
|| !serviceForm.getOpenAIModelName().getName().equals(openAISettings.getModelName())
|| !serviceForm.getOpenAICustomModelName().equals(openAISettings.getCustomModelName())
|| !serviceForm.getAIGatewayBaseHost().equals(aiGatewaySettings.getModelBaseHost(selectedModelType.getName()))
|| !serviceForm.getOpenAIKey().equals(openAISettings.getPrivateKey())
|| !serviceForm.getCodeLlamaBaseHost().equals(codeLlamaSettings.getModelHost())
|| !serviceForm.getCodeLlamaModelName().equals(codeLlamaSettings.getModelName())
|| !serviceForm.getLanguageIndex().equals(languageSettings.getLanguageIndex());
}

Expand All @@ -72,11 +75,15 @@ public void apply() throws ConfigurationException {
var serviceForm = settingsComponent.getDevPilotConfigForm();
var selectedModel = serviceForm.getSelectedModel();
var selectedModelType = serviceForm.getAIGatewayModel();
var openAIModelName = serviceForm.getOpenAIModelName();

settings.setSelectedModel(selectedModel.getName());
openAISettings.setModelHost(serviceForm.getOpenAIBaseHost());
openAISettings.setPrivateKey(serviceForm.getOpenAIKey());
openAISettings.setModelName(openAIModelName.getName());
openAISettings.setCustomModelName(serviceForm.getOpenAICustomModelName());
codeLlamaSettings.setModelHost(serviceForm.getCodeLlamaBaseHost());
codeLlamaSettings.setModelName(serviceForm.getCodeLlamaModelName());
aiGatewaySettings.setModelBaseHost(selectedModelType.getName(), serviceForm.getAIGatewayBaseHost());
aiGatewaySettings.setSelectedModel(selectedModelType.getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
public class CodeLlamaSettingsState implements PersistentStateComponent<CodeLlamaSettingsState> {
private String modelHost;

private String modelName = "CodeLlama-7b";

public static CodeLlamaSettingsState getInstance() {
return ApplicationManager.getApplication().getService(CodeLlamaSettingsState.class);
}
Expand All @@ -25,6 +27,14 @@ public void setModelHost(String modelHost) {
this.modelHost = modelHost;
}

public String getModelName() {
return modelName;
}

public void setModelName(String modelName) {
this.modelName = modelName;
}

@Override
public @Nullable CodeLlamaSettingsState getState() {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import com.intellij.openapi.components.State;
import com.intellij.openapi.components.Storage;
import com.intellij.util.xmlb.XmlSerializerUtil;
import com.zhongan.devpilot.enums.OpenAIModelNameEnum;

@State(name = "DevPilot_OpenAISettings", storages = @Storage("DevPilot_OpenAISettings.xml"))
public class OpenAISettingsState implements PersistentStateComponent<OpenAISettingsState> {
private String modelHost;

private String privateKey;

private String modelName = OpenAIModelNameEnum.GPT3_5_TURBO.getName();

private String customModelName;

public static OpenAISettingsState getInstance() {
return ApplicationManager.getApplication().getService(OpenAISettingsState.class);
}
Expand All @@ -32,6 +37,22 @@ public void setPrivateKey(String privateKey) {
this.privateKey = privateKey;
}

public String getModelName() {
return modelName;
}

public void setModelName(String modelName) {
this.modelName = modelName;
}

public String getCustomModelName() {
return customModelName;
}

public void setCustomModelName(String customModelName) {
this.customModelName = customModelName;
}

@Override
public OpenAISettingsState getState() {
return this;
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/devpilot_en.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
devpilot.settings.service.title=Service Configuration
devpilot.settings.service.modelHostLabel=Model Host
devpilot.settings.service.modelTypeLabel=Choose Model
devpilot.settings.service.modelNameLabel=Model Name
devpilot.settings.service.customModelNameLabel=Custom Model Name
devpilot.settings.service.apiKeyLabel=API Key
devpilot.settings=DevPilot Settings
notification.group.devpilot=DevPilot Notification Group
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/devpilot_zh.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
devpilot.settings.service.title=\u670D\u52A1\u914D\u7F6E
devpilot.settings.service.modelHostLabel=\u6a21\u578b Host
devpilot.settings.service.modelTypeLabel=\u9009\u62e9\u6a21\u578b
devpilot.settings.service.modelNameLabel=\u6a21\u578b\u540d\u79f0
devpilot.settings.service.customModelNameLabel=\u81ea\u5b9a\u4e49\u6a21\u578b\u540d
devpilot.settings.service.apiKeyLabel=API \u5bc6\u94a5
devpilot.settings=DevPilot\u8BBE\u7F6E
notification.group.devpilot=DevPilot\u901A\u77E5\u7EC4
Expand Down

0 comments on commit 7e450c4

Please sign in to comment.