Skip to content

Commit

Permalink
refresh-gpt-chat v0.7.0
Browse files Browse the repository at this point in the history
* 新增变量/v1/images/edits,用于控制图片编辑调用模型
  • Loading branch information
Yanyutin753 committed Apr 8, 2024
1 parent 2299d9e commit 084bce6
Show file tree
Hide file tree
Showing 22 changed files with 293 additions and 194 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ USER root
ENV LANG C.UTF-8

# 复制JAR文件到容器的/app目录下
COPY ./target/refresh-gpt-chat-0.6.0.jar /app/app.jar
COPY ./target/refresh-gpt-chat-0.7.0.jar /app/app.jar

# 切换到/app目录
WORKDIR /app
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
</parent>
<groupId>com.yyandywt99</groupId>
<artifactId>refresh-gpt-chat</artifactId>
<version>0.6.0</version>
<version>0.7.0</version>
<name>refresh-gpt-chat</name>
<description>refresh-gpt-chat</description>
<properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public ResponseEntity<String> error() {
" <title>Document</title>\n" +
"</head>\n" +
"<body>\n" +
" <p>Thanks you use refresh-gpt-chat 0.6.0</p>\n" +
" <p>Thanks you use refresh-gpt-chat 0.7.0</p>\n" +
" <p><a href=\"https://apifox.com/apidoc/shared-4b9a7517-3f80-47a1-84fc-fcf78827a04a\">详细使用文档</a></p>\n" +
" <p><a href=\"https://github.com/Yanyutin753/refresh-gpt-chat\">项目地址</a></p>\n" +
"</body>\n" +
Expand Down
197 changes: 110 additions & 87 deletions src/main/java/com/refresh/gptChat/controller/chatController.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import com.refresh.gptChat.service.tokenService;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
Expand All @@ -31,7 +29,6 @@
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
Expand Down Expand Up @@ -62,6 +59,10 @@ public class chatController {
* speech接口
*/
private static final String speechPath = "/v1/audio/speech";
/**
* edit接口
*/
private static final String editPath = "/v1/images/edits";
/**
* utf-8类型
*/
Expand Down Expand Up @@ -256,91 +257,30 @@ public ResponseEntity<Object> imageConversation(HttpServletResponse response,
if (request_url.contains(chatPath)) {
imageUrl = request_url.replace(chatPath, "");
}
if (!imageUrl.contains("oaifree")) {
String json = com.alibaba.fastjson2.JSON.toJSONString(conversation);
RequestBody requestBody = RequestBody.create(json, mediaType);
// 去除指定部分
imageUrl = request_url + imagePath;
log.info("请求image回复接口:" + imageUrl);
Request.Builder requestBuilder = new Request.Builder().url(imageUrl).post(requestBody);
headersMap.forEach(requestBuilder::addHeader);
Request streamRequest = requestBuilder.build();
try (Response resp = client.newCall(streamRequest).execute()) {
if (!resp.isSuccessful()) {
processService.imageManageUnsuccessfulResponse(refreshTokenList, resp,
refresh_token, response, conversation, imageUrl,
request_id);
} else {
// 回复image回答
outPutService.outPutImage(response, resp, conversation);
}
} catch (ResponseStatusException e) {
return new ResponseEntity<>(e.getMessage(), e.getStatus());
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
}
} else {
String json = "{\n" +
" \"model\": \"" + (image_mobel != null ? image_mobel : "gpt-4") + "\",\n" +
" \"stream\": false,\n" +
" \"messages\": [\n" +
" {\n" +
" \"content\": \"" + conversation.getPrompt() + "\",\n" +
" \"role\": \"user\"\n" +
" }\n" +
" ]\n" +
"}";
RequestBody requestBody = RequestBody.create(json, mediaType);
// 去除指定部分
imageUrl = request_url.replace(chatPath, "") + chatPath;
log.info("请求image回复接口:" + imageUrl);
Request.Builder requestBuilder = new Request.Builder().url(imageUrl).post(requestBody);
headersMap.forEach(requestBuilder::addHeader);
Request streamRequest = requestBuilder.build();
try (Response resp = client.newCall(streamRequest).execute()) {
if (!resp.isSuccessful()) {
processService.imageManageUnsuccessfulResponse(refreshTokenList, resp,
refresh_token, response, conversation, imageUrl,
request_id);
} else {
String respStr = resp.body().string();
JSONObject jsonObject = new JSONObject(respStr);
String created = jsonObject.getString("created");
JSONArray choicesArray = jsonObject.getJSONArray("choices");
if (choicesArray.length() > 0) {
JSONObject firstChoice = choicesArray.getJSONObject(0);
JSONObject messageObject = firstChoice.getJSONObject("message");
String content = messageObject.getString("content");
Matcher matcher = pattern.matcher(content);
if (matcher.find()) {
String urlAndText = matcher.group(1);
String[] splitArray = urlAndText.split(" ", 2);
if (splitArray.length == 2) {
String url = splitArray[0].trim();
String reply = "```\n{ " + splitArray[1].trim() + "}\n```";
JSONObject dataObject = new JSONObject();
dataObject.put("url", url);
JSONObject newJson = new JSONObject();
newJson.put("created", created);
newJson.put("data", dataObject);
newJson.put("reply", reply);
outPutService.outPutOaifreeImage(response, newJson, conversation);
}
}
} else {
return new ResponseEntity<>(Result.error("INTERNAL SERVER ERROR"), HttpStatus.INTERNAL_SERVER_ERROR);
}
}

} catch (ResponseStatusException e) {
return new ResponseEntity<>(e.getMessage(), e.getStatus());
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
String json = com.alibaba.fastjson2.JSON.toJSONString(conversation);
RequestBody requestBody = RequestBody.create(json, mediaType);
// 去除指定部分
imageUrl = request_url + imagePath;
log.info("请求image回复接口:" + imageUrl);
Request.Builder requestBuilder = new Request.Builder().url(imageUrl).post(requestBody);
headersMap.forEach(requestBuilder::addHeader);
Request streamRequest = requestBuilder.build();
try (Response resp = client.newCall(streamRequest).execute()) {
if (!resp.isSuccessful()) {
processService.imageManageUnsuccessfulResponse(refreshTokenList, resp,
refresh_token, response, conversation, imageUrl,
request_id);
} else {
// 回复image回答
outPutService.outPutImage(response, resp, conversation);
}
} catch (ResponseStatusException e) {
return new ResponseEntity<>(e.getMessage(), e.getStatus());
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
}

} catch (IllegalArgumentException e) {
return new ResponseEntity<>(Result.error(e.getMessage()), HttpStatus.BAD_REQUEST);
} catch (Exception e) {
throw new RuntimeException(e);
}
return null;
}, executor);
Expand Down Expand Up @@ -438,7 +378,7 @@ public ResponseEntity<Object> AudioConversation(HttpServletResponse response,
return new ResponseEntity<>("File is too large, limit is: " + MAX_FILE_SIZE, HttpStatus.BAD_REQUEST);
}
String filename = file.getOriginalFilename();
log.info("上传文件名:" + filename + "\n上传文件名:" + file.getSize());
log.info("上传文件名:" + filename + " 上传文件名:" + file.getSize());
log.info("上传模型:" + model);
if (model == null || model.trim().isEmpty()) {
return new ResponseEntity<>("Model cannot be empty", HttpStatus.BAD_REQUEST);
Expand Down Expand Up @@ -489,6 +429,89 @@ public ResponseEntity<Object> AudioConversation(HttpServletResponse response,
return outPutService.getObjectResponseEntity(response, future);
}

/**
* 自定义v1/images/edits接口
* 请求体不是json 会报Request body is missing or not in JSON format
* Authorization token缺失 会报Authorization header is missing
* 无法请求到access_token 会报refresh_token is wrong
*
* @param response
* @param request
* @return
* @throws JSONException
* @throws IOException
*/
@PostMapping(value = "/v1/images/edits")
public ResponseEntity<Object> AudioConversation(HttpServletResponse response,
HttpServletRequest request,
@RequestPart("image") MultipartFile image,
@RequestPart("mask") MultipartFile mask,
@RequestPart("prompt") String prompt,
@RequestPart("n") String n ){
if (image.isEmpty() || mask.isEmpty()) {
return new ResponseEntity<>("Missing Image or Mask", HttpStatus.BAD_REQUEST);
}
String imageName = image.getOriginalFilename();
String maskName = mask.getOriginalFilename();
log.info("上传Image名:" + imageName + " 上传大小:" + image.getSize());
log.info("上传Mask名:" + maskName + " 上传大小:" + mask.getSize());
log.info("prompt:" + prompt);
log.info("n:" + n);
if (prompt == null || prompt.trim().isEmpty()) {
return new ResponseEntity<>("prompt cannot be empty ", HttpStatus.BAD_REQUEST);
}
if (n == null || Integer.parseInt(n) <= 0) {
return new ResponseEntity<>("n cannot be empty and n >= 1", HttpStatus.BAD_REQUEST);
}
String header = request.getHeader("Authorization");
String authorizationHeader = (header != null && !header.trim().isEmpty()) ? header.trim() : null;
CompletableFuture<ResponseEntity<Object>> future =
CompletableFuture.supplyAsync(() -> {
try {
String[] result = messageService.extractApiKeyAndRequestUrl(authorizationHeader);
String refresh_token = result[0];
String request_url = result[1];
String request_id = result[2];
String access_token = getAccess_token(refresh_token);
String editUrl = request_url;
if (request_url.contains(chatPath)) {
editUrl = request_url.replace(chatPath, "");
}
editUrl = editUrl + editPath;
Map<String, String> headersMap = tokenService.addHeader(access_token, request_id);
RequestBody imageBody = RequestBody.create(image.getBytes(),
MediaType.parse("application/octet-stream"));
RequestBody maskBody = RequestBody.create(mask.getBytes(),
MediaType.parse("application/octet-stream"));
log.info("请求image edits 回复接口:" + editUrl);
RequestBody body = new MultipartBody.Builder()
.setType(MultipartBody.FORM)
.addFormDataPart("prompt", prompt)
.addFormDataPart("n", n)
.addFormDataPart("image", imageName, imageBody)
.addFormDataPart("mask", maskName, maskBody)
.build();
Request.Builder requestBuilder = new Request.Builder()
.url(editUrl)
.post(body);
headersMap.forEach(requestBuilder::addHeader);
try (Response resp = client.newCall(requestBuilder.build()).execute()) {
log.info(resp.toString());
if (!resp.isSuccessful()) {
processService.editManageUnsuccessfulResponse(refreshTokenList, resp,
refresh_token, response, imageBody,imageName,maskBody,maskName,prompt,n, editUrl,request_id);
} else {
outPutService.outPutEdit(response, resp);
}
}
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
}
return null;
}, executor);
return outPutService.getObjectResponseEntity(response, future);
}

private String getAccess_token(String refresh_token) {
String access_token = refresh_token;
boolean is_access = refresh_token.startsWith("eyJhb");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ public void initialize() {
}
System.out.println("参数self_server_uuid:" + self_server_uuid);
System.out.println();
System.out.println("----------原神refresh-gpt-chat v0.6.0启动成功------------");
System.out.println("----------原神refresh-gpt-chat v0.7.0启动成功------------");
System.out.println("1.新增oaifree作为服务商,支持refresh_token自动刷新成access_token");
System.out.println("2.新增接口**/getAccountID**,获取ChatGPT-Account-ID");
System.out.println("3.新增画图dall-e-3接口/v1/images/generations\n" +
"4.新增文字转语音接口/v1/audio/speech\"\n" +
"5.新增语言转文字接口/v1/audio/transcriptions");
System.out.println("6.新增变量image_mobel(gpt-4或gpt-4-mobile),用于控制画图接口调用模型");
System.out.println("7.重构代码,逻辑更加清晰,结构更加合理");
System.out.println("6.重构代码,逻辑更加清晰,结构更加合理");
System.out.println("7.新增变量/v1/images/edits,用于控制图片编辑调用模型");
System.out.println("URL地址:http://0.0.0.0:" + serverPort + prefix);
System.out.println("------------------------------------------------------");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,4 @@ public String[] extractApiKeyAndRequestUrl(String authorizationHeader) throws Il
System.arraycopy(tempResult, 0, finalResult, 0, Math.min(tempResult.length, 3));
return finalResult;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
import com.refresh.gptChat.service.outPutService;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import org.json.JSONObject;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;

import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -127,35 +125,36 @@ public void outPutSpeech(HttpServletResponse response, Response resp, Speech con
}
}


@Override
public void outPutOaifreeImage(HttpServletResponse response, JSONObject newJson, Image conversation) {
public void outPutAudio(HttpServletResponse response, Response resp, String temModel) {
try {
response.setContentType("application/json; charset=utf-8");
String model = (conversation.getModel() != null) ? conversation.getModel() : "dell-e-3";
String model = temModel != null ? temModel : "whisper-1";
OutputStream out = new BufferedOutputStream(response.getOutputStream());

// newJson需要是JSONObject或者JSONArray
if (!(newJson instanceof JSONObject)) {
throw new IllegalArgumentException("newJson object must be an instance of JSONObject or JSONArray");
InputStream in = new BufferedInputStream(resp.body().byteStream());
// 一次拿多少数据 迭代循环
byte[] buffer = new byte[8192];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
out.flush();
}
String jsonString = newJson.toString();

byte[] data = jsonString.getBytes(StandardCharsets.UTF_8);

out.write(data);
out.flush();

log.info("使用模型:" + model + "," + newJson);
log.info("使用模型:" + model + "," + resp);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

/**
* edit接口的输出
* @param response
* @param resp
*/
@Override
public void outPutAudio(HttpServletResponse response, Response resp, String temModel) {
public void outPutEdit(HttpServletResponse response, Response resp) {
try {
response.setContentType("application/json; charset=utf-8");
String model = temModel != null ? temModel : "whisper-1";
OutputStream out = new BufferedOutputStream(response.getOutputStream());
InputStream in = new BufferedInputStream(resp.body().byteStream());
// 一次拿多少数据 迭代循环
Expand All @@ -165,7 +164,7 @@ public void outPutAudio(HttpServletResponse response, Response resp, String temM
out.write(buffer, 0, bytesRead);
out.flush();
}
log.info("使用模型:" + model + "," + resp);
log.info("使用edits接口编辑图片, " + resp);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Loading

0 comments on commit 084bce6

Please sign in to comment.