From 4d700be9ae65059476ac781cab44302506a8e5d3 Mon Sep 17 00:00:00 2001 From: FelixJiang Date: Sat, 1 Apr 2023 20:16:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=B8=87=E7=94=A8=E8=A1=A8?= =?UTF-8?q?=E6=A0=BCPrompt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/service/system/sys_chatgpt.go | 101 ++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 9 deletions(-) diff --git a/server/service/system/sys_chatgpt.go b/server/service/system/sys_chatgpt.go index ffe2d37c..b89e671a 100644 --- a/server/service/system/sys_chatgpt.go +++ b/server/service/system/sys_chatgpt.go @@ -2,7 +2,6 @@ package system import ( "context" - "encoding/json" "errors" "fmt" "github.com/flipped-aurora/gin-vue-admin/server/global" @@ -10,6 +9,7 @@ import ( "github.com/flipped-aurora/gin-vue-admin/server/model/system/request" "github.com/sashabaranov/go-openai" "gorm.io/gorm" + "strings" ) type ChatGptService struct{} @@ -43,10 +43,11 @@ func (chat *ChatGptService) GetTable(req request.ChatGptRequest) (sql string, re return "", nil, errors.New("未选择db") } var tablesInfo []system.ChatField + var tableName string global.GVA_DB.Table("information_schema.columns").Where("TABLE_SCHEMA = ?", req.DBName).Scan(&tablesInfo) - b, err := json.Marshal(tablesInfo) - if err != nil { - return + + for i := range tablesInfo { + tableName += tablesInfo[i].TABLE_NAME + "," } option, err := chat.GetSK() if err != nil { @@ -55,12 +56,44 @@ func (chat *ChatGptService) GetTable(req request.ChatGptRequest) (sql string, re client := openai.NewClient(option.SK) ctx := context.Background() + tables, err := getTables(ctx, client, tableName, req.Chat) + if err != nil { + return "", nil, err + } + tableArr := strings.Split(tables, ",") + + sql, err = getSql(ctx, client, tableArr, tablesInfo, req.Chat) + if err != nil { + return "", nil, err + } + err = global.GVA_DB.Raw(sql).Scan(&results).Error + return sql, results, err +} + +func getTables(ctx context.Context, client *openai.Client, tables string, chat string) (string, error) { + var tablePrompt = `You are a database administrator + +If I want to query at least those tables, I will provide you with the following table configuration information: + +Table 1, Table 2, Table 3 + +Please return the table name I need according to the input format + +Please do not return information other than the table + +Configured as: + +%s + +The problem is: +%s +` chatReq := openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + Model: openai.GPT3TextDavinci003, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, - Content: fmt.Sprintf("数据库所有字段用json表示,表名为TABLE_NAME,列名为COLUMN_NAME,列描述为COLUMN_COMMENT,%s,根据语句帮我生成单纯的查询sql,,不要提示语\n+%s", string(b), req.Chat), + Content: fmt.Sprintf(tablePrompt, tables, chat), }, }, } @@ -68,9 +101,59 @@ func (chat *ChatGptService) GetTable(req request.ChatGptRequest) (sql string, re resp, err := client.CreateChatCompletion(ctx, chatReq) if err != nil { fmt.Printf("Completion error: %v\n", err) - return + return "", err + } + return resp.Choices[0].Message.Content, nil +} + +func getSql(ctx context.Context, client *openai.Client, tables []string, ChatField []system.ChatField, chat string) (string, error) { + var sqlPrompt = `You are a database administrator + +Give me an SQL statement based on my question + +I will provide you with my current database table configuration information in the form below + +Table Name | Column Name | Column Description + +Do not return information other than SQL + +Configured as: + +user | username | 用户名 +user | sex | 性别 +user | age | 年龄 +user | pwd | 密码 + +The problem is: + +%s` + var configured string + + var tablesMap = make(map[string]bool) + for i := range tables { + tablesMap[tables[i]] = true } - err = global.GVA_DB.Raw(resp.Choices[0].Message.Content).Scan(&results).Error - return resp.Choices[0].Message.Content, results, err + for i := range ChatField { + if tablesMap[ChatField[i].TABLE_NAME] { + configured += fmt.Sprintf("%s | %s | %s \n", ChatField[i].TABLE_NAME, ChatField[i].COLUMN_NAME, ChatField[i].COLUMN_COMMENT) + } + } + + chatReq := openai.ChatCompletionRequest{ + Model: openai.GPT3TextDavinci003, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: fmt.Sprintf(sqlPrompt, configured, chat), + }, + }, + } + + resp, err := client.CreateChatCompletion(ctx, chatReq) + if err != nil { + fmt.Printf("Completion error: %v\n", err) + return "", err + } + return resp.Choices[0].Message.Content, nil }