diff --git a/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java b/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java new file mode 100644 index 0000000..9194287 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java @@ -0,0 +1,81 @@ +package com.unfbx.chatgpt; + +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; +import com.unfbx.chatgpt.entity.completions.Completion; +import com.unfbx.chatgpt.entity.completions.CompletionResponse; +import com.unfbx.chatgpt.entity.embeddings.Embedding; +import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse; +import io.reactivex.Single; +import retrofit2.http.Body; +import retrofit2.http.POST; + +/** + * Azure OpenAI api接口 + * api版本:2023-03-15-preview + * + * apiHost: + * https://${your-resource-name}.openai.azure.com/openai/deployments/${deployment-id}/ + * + * 文档: + * https://learn.microsoft.com/zh-cn/azure/cognitive-services/openai/reference + * swagger:https://github.com/Azure/azure-rest-api-specs/blob/main/specification/cognitiveservices/data-plane/AzureOpenAI/inference/stable/2022-12-01/inference.json + * + * @author skywalker + * @since 2023/5/7 17:22 + */ +public interface AzureOpenAiApi extends OpenAiApi { + + /** + * 与OpenAiApi接口保持一直,不额外增加参数传递api-version字段 + */ + + + /** + * 文本问答 + * Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. + * + * 注意: + * logprobs, best_of and echo parameters are not available on gpt-35-turbo model. + * azure版本api,在gpt-35-turbo model下,不支持传递logprobs, best_of, echo这3个参数,需要置为null + * + * 示例: + * Completion q = Completion.builder() + * .prompt("who are you?") + * .logprobs(null) + * .bestOf(null) + * .echo(null) + * .maxTokens(16) + * .build(); + * + * @param completion 问答参数 + * @return Single CompletionResponse + */ + @POST("completions" ) + Single completions(@Body Completion completion); + + /** + * 文本向量计算 + * + * 注意: + * Too many inputs for model None. The max number of inputs is 1. We hope to increase the number of inputs per request soon. + * Azure版本api只支持传递一个input + * + * @param embedding 向量参数 + * @return Single EmbeddingResponse + */ + @POST("embeddings" ) + Single embeddings(@Body Embedding embedding); + + /** + * 最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 + * + * @param chatCompletion chat completion + * @return 返回答案 + */ + @Override + @POST("chat/completions" ) + Single chatCompletion(@Body ChatCompletion chatCompletion); + + +} diff --git a/src/main/java/com/unfbx/chatgpt/OpenAiAzureStreamClient.java b/src/main/java/com/unfbx/chatgpt/OpenAiAzureStreamClient.java new file mode 100644 index 0000000..cb50e26 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/OpenAiAzureStreamClient.java @@ -0,0 +1,338 @@ +package com.unfbx.chatgpt; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.ObjectUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.http.ContentType; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.unfbx.chatgpt.constant.OpenAIConst; +import com.unfbx.chatgpt.entity.chat.BaseChatCompletion; +import com.unfbx.chatgpt.entity.completions.Completion; +import com.unfbx.chatgpt.entity.completions.CompletionResponse; +import com.unfbx.chatgpt.exception.BaseException; +import com.unfbx.chatgpt.exception.CommonError; +import com.unfbx.chatgpt.interceptor.AzureOpenAiApiVersionInterceptor; +import com.unfbx.chatgpt.interceptor.AzureOpenAiAuthInterceptor; +import com.unfbx.chatgpt.interceptor.DefaultOpenAiAuthInterceptor; +import com.unfbx.chatgpt.interceptor.DynamicKeyOpenAiAuthInterceptor; +import com.unfbx.chatgpt.sse.ConsoleEventSourceListener; +import io.reactivex.Single; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.jetbrains.annotations.NotNull; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.util.*; +import java.util.concurrent.TimeUnit; + + +/** + * 描述: open ai 客户端 + * + * @author https:www.unfbx.com + * 2023-02-28 + */ + +@Slf4j +public class OpenAiAzureStreamClient { + @Getter + @NotNull + private List apiKey; + /** + * 自定义api host使用builder的方式构造client + */ + @Getter + private List apiHost; + + @Getter + private Integer apiHostIndex; + + @Getter + private Map cacheMap ; + + @Getter + private String apiVerison; + + +// @Getter +// private Map auditCacheMap ; +// +// @Getter +// private Map timeCacheMap ; +// +// +// @Getter +// private Long exceededTokenlimit; +// +// @Getter +// private Long auditToken; + + /** + * 自定义的okHttpClient + * 如果不自定义 ,就是用sdk默认的OkHttpClient实例 + */ + @Getter + private OkHttpClient okHttpClient; + + /** + * api key的获取策略 + */ + + @Getter + private List azureOpenAiApis; + + /** + * 自定义鉴权处理拦截器
+ * 可以不设置,默认实现:DefaultOpenAiAuthInterceptor
+ * 如需自定义实现参考:DealKeyWithOpenAiAuthInterceptor + * + * @see DynamicKeyOpenAiAuthInterceptor + * @see DefaultOpenAiAuthInterceptor + */ + @Getter + private AzureOpenAiAuthInterceptor azureOpenAiAuthInterceptor; + + @Getter + private AzureOpenAiApiVersionInterceptor apiVersionInterceptor; + + /** + * 构造实例对象 + * + * @param builder + */ + private OpenAiAzureStreamClient(Builder builder) { + if (CollectionUtil.isEmpty(builder.apiKey)) { + throw new BaseException(CommonError.API_KEYS_NOT_NUL); + } + apiKey = builder.apiKey; + + if (CollectionUtil.isEmpty(builder.apiHost)) { + builder.apiHost = Arrays.asList(OpenAIConst.OPENAI_HOST); + } + apiHost = builder.apiHost; + apiHostIndex = 0; + cacheMap = new HashMap(); +// auditCacheMap = new HashMap<>(); +// timeCacheMap = new HashMap<>(); + cacheMap.put("cacheIndex", apiHostIndex); + + + if(StrUtil.isEmpty(builder.apiVerison)){ + builder.apiVerison = "2023-07-01-preview"; + } + apiVerison = builder.apiVerison; + + builder.azureOpenAiAuthInterceptor = new AzureOpenAiAuthInterceptor(); + azureOpenAiAuthInterceptor = builder.azureOpenAiAuthInterceptor; + //设置apiKeys和key的获取策略 + azureOpenAiAuthInterceptor.setApiKeys(this.apiKey); + azureOpenAiAuthInterceptor.setCacheMap(this.cacheMap); + + apiVersionInterceptor = new AzureOpenAiApiVersionInterceptor(); + apiVersionInterceptor.setApiVersion(this.apiVerison); + + +// if(ObjectUtil.isEmpty(builder.exceededTokenlimit)){ +// builder.exceededTokenlimit = 0l; +// } +// exceededTokenlimit = builder.exceededTokenlimit; +// +// if(ObjectUtil.isEmpty(builder.auditToken)){ +// builder.auditToken = 8000l; +// } +// auditToken = builder.auditToken; + + if (Objects.isNull(builder.okHttpClient)) { + builder.okHttpClient = this.okHttpClient(); + } else { + builder.okHttpClient = builder.okHttpClient + .newBuilder() + .addInterceptor(apiVersionInterceptor) + .addInterceptor(azureOpenAiAuthInterceptor) + .build(); + } + + okHttpClient = builder.okHttpClient; + + + this.azureOpenAiApis = new ArrayList<>(); + for (int i = 0;i< apiHost.size();i++) { + Retrofit.Builder builderx = new Retrofit.Builder() + .baseUrl(apiHost.get(i)) + .client(okHttpClient) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .addConverterFactory(JacksonConverterFactory.create()); + azureOpenAiApis.add(builderx.build().create(AzureOpenAiApi.class)); + } + } + + /** + * 创建默认的OkHttpClient + */ + private OkHttpClient okHttpClient() { + if (Objects.isNull(this.azureOpenAiAuthInterceptor)) { + this.azureOpenAiAuthInterceptor = new AzureOpenAiAuthInterceptor(); + } + this.azureOpenAiAuthInterceptor.setApiKeys(this.apiKey); + this.azureOpenAiAuthInterceptor.setCacheMap(this.cacheMap); + this.apiVersionInterceptor.setApiVersion(this.apiVerison); + + return new OkHttpClient + .Builder() + .addInterceptor(this.apiVersionInterceptor) + .addInterceptor(this.azureOpenAiAuthInterceptor) + .connectTimeout(20, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); + } + //这个是给rifet的 + public CompletionResponse completions(String question) { + Completion q = Completion.builder() + .prompt(question) + .build(); + Single completions = this.azureOpenAiApis.get(apiHostIndex).completions(q); + return completions.blockingGet(); + } + + /** + * 流式输出,最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 + * + * @param chatCompletion 问答参数 + * @param eventSourceListener sse监听器 + * @see ConsoleEventSourceListener + */ + public void streamAzureChatCompletion(T chatCompletion, EventSourceListener eventSourceListener,String tagid) { + if (Objects.isNull(eventSourceListener)) { + log.error("参数异常:EventSourceListener不能为空,可以参考:com.unfbx.chatgpt.sse.ConsoleEventSourceListener"); + throw new BaseException(CommonError.PARAM_ERROR); + } + if (!chatCompletion.isStream()) { + chatCompletion.setStream(true); + } + try { + EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(chatCompletion); + + Request request = new Request.Builder() + .tag(tagid) + .url(getApiHostUrl() + "chat/completions") + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + //创建事件 + + EventSource eventSource = factory.newEventSource(request, eventSourceListener); + } catch (JsonProcessingException e) { + log.error("请求参数解析异常:{}", e); + e.printStackTrace(); + } catch (Exception e) { + log.error("请求参数解析异常:{}", e); + e.printStackTrace(); + } + } + + private String getApiHostUrl() { + SequentialKey(); + return this.apiHost.get(this.apiHostIndex); + } + private Integer SequentialKey() { + //当前index值+1使用如果当前值已经超过最大值就从头开始 + if(this.apiHostIndex apiKey; + /** + * api请求地址,结尾处有斜杠 + * + * @see OpenAIConst + */ + private List apiHost; + +// private Long exceededTokenlimit; +// +// private Long auditToken; + + private String apiVerison; + + + private OkHttpClient okHttpClient; + + /** + * 自定义鉴权拦截器 + */ + private AzureOpenAiAuthInterceptor azureOpenAiAuthInterceptor; + + public Builder() { + } + + + public Builder apiKey(@NotNull List val) { + apiKey = val; + return this; + } + + /** + * @param val api请求地址,结尾处有斜杠 + * @return Builder + * @see OpenAIConst + */ + public Builder apiHost(@NotNull List val) { + apiHost = val; + return this; + } + + + public Builder apiVerison(@NotNull String val) { + apiVerison = val; + return this; + } + +// public Builder exceededTokenlimit(Long val) { +// exceededTokenlimit = val; +// return this; +// } +// public Builder auditToken(Long val) { +// auditToken = val; +// return this; +// } + public Builder okHttpClient(OkHttpClient val) { + okHttpClient = val; + return this; + } + + public Builder authInterceptor(AzureOpenAiAuthInterceptor val) { + azureOpenAiAuthInterceptor = val; + return this; + } + + public OpenAiAzureStreamClient build() { + return new OpenAiAzureStreamClient(this); + } + + + } +} diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiApiVersionInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiApiVersionInterceptor.java new file mode 100644 index 0000000..b682ab0 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiApiVersionInterceptor.java @@ -0,0 +1,35 @@ +package com.unfbx.chatgpt.interceptor; + +import lombok.Getter; +import lombok.Setter; +import okhttp3.HttpUrl; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; + +/** + * @author skywalker + * @since 2023/5/7 17:22 + */ +public class AzureOpenAiApiVersionInterceptor implements Interceptor { + + + @Getter + @Setter + private String apiVersion; + + @Override + public Response intercept(Chain chain) throws IOException { + Request original = chain.request(); + HttpUrl originalHttpUrl = original.url(); + HttpUrl url = originalHttpUrl.newBuilder() + .addQueryParameter("api-version", apiVersion) + .build(); + Request.Builder requestBuilder = original.newBuilder() + .url(url); + Request request = requestBuilder.build(); + return chain.proceed(request); + } +} diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java new file mode 100644 index 0000000..0abf248 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java @@ -0,0 +1,55 @@ +package com.unfbx.chatgpt.interceptor; + +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import lombok.Getter; +import lombok.Setter; +import okhttp3.HttpUrl; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @author skywalker + * @since 2023/5/7 17:22 + */ +public class AzureOpenAiAuthInterceptor implements Interceptor { + + + @Getter + @Setter + private List apiKeys; + + @Getter + @Setter + private Map cacheMap; + + + public Request auth(String key, Request original) { + + Request request = original.newBuilder() + .header("api-key",key) + .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .method(original.method(), original.body()) + .build(); + return request; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request original = chain.request(); + return chain.proceed(auth(this.getKey(), original)); + } + + private String getKey() { + Integer count = this.cacheMap.get("cacheIndex"); + return this.apiKeys.get(count.intValue()); + } + + +} diff --git a/src/test/java/com/unfbx/chatgpt/OpenAzureAiClientTest.java b/src/test/java/com/unfbx/chatgpt/OpenAzureAiClientTest.java new file mode 100644 index 0000000..224f5bf --- /dev/null +++ b/src/test/java/com/unfbx/chatgpt/OpenAzureAiClientTest.java @@ -0,0 +1,86 @@ +package com.unfbx.chatgpt; + +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.completions.CompletionResponse; +import com.unfbx.chatgpt.interceptor.OpenAILogger; +import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor; +import com.unfbx.chatgpt.sse.ConsoleEventSourceListener; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * 描述: 测试类 + * + * @author Nixer + * 2023-02-11 + */ +@Slf4j +public class OpenAzureAiClientTest { + + private OpenAiAzureStreamClient v3; + + @Before + public void before() { + //可以为null +// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)); + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + //!!!!千万别再生产或者测试环境打开BODY级别日志!!!! + //!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!! + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); + List apikeys =Arrays.asList("ef*********************************"); + List apiHost = Arrays.asList("https://gptservice01.openai.azure"); + OkHttpClient okHttpClient = new OkHttpClient + .Builder() +// .proxy(proxy) + .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new OpenAiResponseInterceptor()) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS) + .build(); + v3 = OpenAiAzureStreamClient.builder() + .apiKey(apikeys) + .okHttpClient(okHttpClient) + .apiHost(apiHost) + .build(); + } + + @Test + public void chatCompletions() { + ConsoleEventSourceListener eventSourceListener = new ConsoleEventSourceListener(); + Message message = Message.builder().role(Message.Role.USER).content("random one word!").build(); + ChatCompletion chatCompletion = ChatCompletion + .builder() + .model(ChatCompletion.Model.GPT_3_5_TURBO.getName()) + .temperature(0.2) + .maxTokens(2048) + .messages(Collections.singletonList(message)) + .stream(true) + .build(); + for (int i = 0; i < 2; i++) { + v3.streamAzureChatCompletion(chatCompletion, eventSourceListener,null); + CountDownLatch countDownLatch = new CountDownLatch(1); + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + + +/* @Test + public void completions() { + CompletionResponse completions = v3.completions("什么?"); + (completions.getChoices()).forEach(System.out::println); + }*/ +}