diff --git a/mllm-chat b/mllm-chat index c46197fd9..898e94b9c 160000 --- a/mllm-chat +++ b/mllm-chat @@ -1 +1 @@ -Subproject commit c46197fd97c6876be9b99e94fe9fb1f9e5aed571 +Subproject commit 898e94b9c68af2223570a2793cd19d8b6b24008a diff --git a/mllm-cli/cmd/mllm-client/main.go b/mllm-cli/cmd/mllm-client/main.go index 969ab48cf..214f4e128 100644 --- a/mllm-cli/cmd/mllm-client/main.go +++ b/mllm-cli/cmd/mllm-client/main.go @@ -50,10 +50,9 @@ func main() { history = history[:len(history)-1] continue } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() log.Printf("ERROR: Server returned status %s: %s", resp.Status, string(bodyBytes)) history = history[:len(history)-1] continue @@ -83,6 +82,7 @@ func main() { } fmt.Println() if err := scanner.Err(); err != nil { log.Printf("ERROR reading stream: %v", err) } + resp.Body.Close() history = append(history, api.RequestMessage{Role: "assistant", Content: fullResponse.String()}) } } \ No newline at end of file diff --git a/mllm-cli/cmd/mllm-server/main.go b/mllm-cli/cmd/mllm-server/main.go index 819e9936b..b51566228 100644 --- a/mllm-cli/cmd/mllm-server/main.go +++ b/mllm-cli/cmd/mllm-server/main.go @@ -18,6 +18,7 @@ import ( func main() { modelPath := flag.String("model-path", "", "Path to the MLLM model directory.") + probePath := flag.String("probe-path", "", "Path to the probes directory for Qwen3 probing session.") ocrModelPath := flag.String("ocr-model-path", "", "Path to the DeepSeek-OCR model directory.") flag.Parse() @@ -35,7 +36,16 @@ func main() { if *modelPath != "" { log.Printf("Loading Qwen3 model and creating session from: %s", *modelPath) - session, err := mllm.NewSession(*modelPath) + var ( + session *mllm.Session + err error + ) + if *probePath != "" { + log.Printf("Probing enabled. Loading probes from: %s", *probePath) + session, err = mllm.NewProbingSession(*modelPath, *probePath) + } else { + session, err = mllm.NewSession(*modelPath) + } if err != nil { log.Fatalf("FATAL: Failed to create Qwen3 session: %v", err) } @@ -89,4 +99,4 @@ func main() { mllmService.Shutdown() log.Println("Server gracefully stopped.") -} \ No newline at end of file +} diff --git a/mllm-cli/go.mod b/mllm-cli/go.mod index aa7f4742e..12ab32281 100644 --- a/mllm-cli/go.mod +++ b/mllm-cli/go.mod @@ -1,10 +1,16 @@ module mllm-cli -go 1.23.0 +go 1.25.0 -toolchain go1.24.11 +require ( + github.com/charmbracelet/bubbles v0.21.0 + golang.org/x/mobile v0.0.0-20260217195705-b56b3793a9c4 +) -require github.com/charmbracelet/bubbles v0.21.0 +require ( + golang.org/x/mod v0.33.0 // indirect + golang.org/x/tools v0.42.0 // indirect +) require ( github.com/atotto/clipboard v0.1.4 // indirect @@ -27,8 +33,8 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/sync v0.11.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.41.0 // indirect golang.org/x/term v0.34.0 golang.org/x/text v0.3.8 // indirect ) diff --git a/mllm-cli/go.sum b/mllm-cli/go.sum index 6a95e88fa..7b4cf43e2 100644 --- a/mllm-cli/go.sum +++ b/mllm-cli/go.sum @@ -22,6 +22,8 @@ github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQ github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -49,13 +51,19 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/mobile v0.0.0-20260217195705-b56b3793a9c4 h1:uT3oYo9M38vJa7JpT4kCie2lJwOpoUrx7FvV0H7kXSc= +golang.org/x/mobile v0.0.0-20260217195705-b56b3793a9c4/go.mod h1:4OGHIUSBiIqyFAQDaX1tpY0BVnO20DvNDeATBu8aeFQ= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/mllm-cli/mllm/c.go b/mllm-cli/mllm/c.go index cacaab1d5..c424cfa6e 100644 --- a/mllm-cli/mllm/c.go +++ b/mllm-cli/mllm/c.go @@ -24,10 +24,9 @@ import "unsafe" import "fmt" import "runtime" - type Session struct { - cHandle C.MllmCAny - sessionID string + cHandle C.MllmCAny + sessionID string } func isOk(any C.MllmCAny) bool { @@ -43,103 +42,122 @@ func ShutdownContext() bool { } func StartService(workerThreads int) bool { - result := C.startService(C.size_t(workerThreads)) - return isOk(result) + result := C.startService(C.size_t(workerThreads)) + return isOk(result) } func StopService() bool { - result := C.stopService() - return isOk(result) + result := C.stopService() + return isOk(result) } func SetLogLevel(level int) { - C.setLogLevel(C.int(level)) + C.setLogLevel(C.int(level)) } func NewSession(modelPath string) (*Session, error) { - cModelPath := C.CString(modelPath) - defer C.free(unsafe.Pointer(cModelPath)) + cModelPath := C.CString(modelPath) + defer C.free(unsafe.Pointer(cModelPath)) + + handle := C.createQwen3Session(cModelPath) + if !isOk(handle) { + return nil, fmt.Errorf("底层C API createQwen3Session 失败") + } + s := &Session{cHandle: handle} + runtime.SetFinalizer(s, func(s *Session) { + fmt.Println("[Go Finalizer] Mllm Session automatically released.") + C.freeSession(s.cHandle) + }) + + return s, nil +} + +func NewProbingSession(modelPath string, probePath string) (*Session, error) { + cModelPath := C.CString(modelPath) + defer C.free(unsafe.Pointer(cModelPath)) + cProbePath := C.CString(probePath) + defer C.free(unsafe.Pointer(cProbePath)) - handle := C.createQwen3Session(cModelPath) - if !isOk(handle) { - return nil, fmt.Errorf("底层C API createQwen3Session 失败") - } - s := &Session{cHandle: handle} - runtime.SetFinalizer(s, func(s *Session) { - fmt.Println("[Go Finalizer] Mllm Session automatically released.") - C.freeSession(s.cHandle) - }) + handle := C.createQwen3ProbingSession(cModelPath, cProbePath) + if !isOk(handle) { + return nil, fmt.Errorf("底层C API createQwen3ProbingSession 失败") + } + s := &Session{cHandle: handle} + runtime.SetFinalizer(s, func(s *Session) { + fmt.Println("[Go Finalizer] Mllm Probing Session automatically released.") + C.freeSession(s.cHandle) + }) - return s, nil + return s, nil } func NewDeepseekOCRSession(modelPath string) (*Session, error) { - cModelPath := C.CString(modelPath) - defer C.free(unsafe.Pointer(cModelPath)) + cModelPath := C.CString(modelPath) + defer C.free(unsafe.Pointer(cModelPath)) - handle := C.createDeepseekOCRSession(cModelPath) - if !isOk(handle) { - return nil, fmt.Errorf("底层C API createDeepseekOCRSession 失败") - } - s := &Session{cHandle: handle} - runtime.SetFinalizer(s, func(s *Session) { - fmt.Println("[Go Finalizer] Mllm OCR Session automatically released.") - C.freeSession(s.cHandle) - }) + handle := C.createDeepseekOCRSession(cModelPath) + if !isOk(handle) { + return nil, fmt.Errorf("底层C API createDeepseekOCRSession 失败") + } + s := &Session{cHandle: handle} + runtime.SetFinalizer(s, func(s *Session) { + fmt.Println("[Go Finalizer] Mllm OCR Session automatically released.") + C.freeSession(s.cHandle) + }) - return s, nil + return s, nil } func (s *Session) Close() { - if C.MllmCAny_get_v_custom_ptr(s.cHandle) != nil { - fmt.Println("[Go Close] Mllm Session manually closed.") - C.freeSession(s.cHandle) - s.cHandle = C.MllmCAny_set_v_custom_ptr_null(s.cHandle) - runtime.SetFinalizer(s, nil) - } + if C.MllmCAny_get_v_custom_ptr(s.cHandle) != nil { + fmt.Println("[Go Close] Mllm Session manually closed.") + C.freeSession(s.cHandle) + s.cHandle = C.MllmCAny_set_v_custom_ptr_null(s.cHandle) + runtime.SetFinalizer(s, nil) + } } func (s *Session) Insert(sessionID string) bool { - cSessionID := C.CString(sessionID) - defer C.free(unsafe.Pointer(cSessionID)) - result := C.insertSession(cSessionID, s.cHandle) - if isOk(result) { - s.sessionID = sessionID - } - return isOk(result) + cSessionID := C.CString(sessionID) + defer C.free(unsafe.Pointer(cSessionID)) + result := C.insertSession(cSessionID, s.cHandle) + if isOk(result) { + s.sessionID = sessionID + } + return isOk(result) } func (s *Session) SendRequest(jsonRequest string) bool { - if s.sessionID == "" { - fmt.Println("[Go SendRequest] Error: sessionID is not set on this session.") - return false - } - cSessionID := C.CString(s.sessionID) - cJsonRequest := C.CString(jsonRequest) - defer C.free(unsafe.Pointer(cSessionID)) - defer C.free(unsafe.Pointer(cJsonRequest)) - - result := C.sendRequest(cSessionID, cJsonRequest) - return isOk(result) -} - -func (s *Session) PollResponse(requestID string) string { - if requestID == "" { - fmt.Println("[Go PollResponse] Error: requestID cannot be empty.") - return "" - } - cRequestID := C.CString(requestID) - defer C.free(unsafe.Pointer(cRequestID)) - - cResponse := C.pollResponse(cRequestID) - if cResponse == nil { - return "" - } - defer C.freeResponseString(cResponse) - - return C.GoString(cResponse) + if s.sessionID == "" { + fmt.Println("[Go SendRequest] Error: sessionID is not set on this session.") + return false + } + cSessionID := C.CString(s.sessionID) + cJsonRequest := C.CString(jsonRequest) + defer C.free(unsafe.Pointer(cSessionID)) + defer C.free(unsafe.Pointer(cJsonRequest)) + + result := C.sendRequest(cSessionID, cJsonRequest) + return isOk(result) +} + +func (s *Session) PollResponse(requestID string) string { + if requestID == "" { + fmt.Println("[Go PollResponse] Error: requestID cannot be empty.") + return "" + } + cRequestID := C.CString(requestID) + defer C.free(unsafe.Pointer(cRequestID)) + + cResponse := C.pollResponse(cRequestID) + if cResponse == nil { + return "" + } + defer C.freeResponseString(cResponse) + + return C.GoString(cResponse) } func (s *Session) SessionID() string { - return s.sessionID + return s.sessionID } diff --git a/mllm-cli/mobile_adapter/mobile_server.go b/mllm-cli/mobile_adapter/mobile_server.go new file mode 100644 index 000000000..e3485a7d8 --- /dev/null +++ b/mllm-cli/mobile_adapter/mobile_server.go @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +package gomllm + +import ( +"log" +"os" +"path/filepath" + +_ "golang.org/x/mobile/bind" +"mllm-cli/mllm" +pkgmllm "mllm-cli/pkg/mllm" +"mllm-cli/pkg/server" +) + +func StartServer(modelPath string, ocrPath string, tmpDir string, enableProbing bool) string { +log.Println("[GoMobile] StartServer called") + +if tmpDir != "" { +if err := os.Setenv("TMPDIR", tmpDir); err != nil { +log.Printf("[GoMobile] Error setting TMPDIR: %v", err) +} else { +log.Printf("[GoMobile] TMPDIR set to: %s", tmpDir) +} +} + +if !mllm.InitializeContext() { +return "Error: InitializeContext failed" +} +mllm.SetLogLevel(2) + +service := pkgmllm.NewService() + +if modelPath != "" { +log.Printf("[GoMobile] Loading Qwen: %s", modelPath) +probePath := filepath.Join(modelPath, "probes_linear") +var ( +session *mllm.Session +err error +) +if enableProbing { +if stat, statErr := os.Stat(probePath); statErr == nil && stat.IsDir() { +log.Printf("[GoMobile] Probing enabled, probes found: %s", probePath) +session, err = mllm.NewProbingSession(modelPath, probePath) +} else { +log.Printf("[GoMobile] Probes not found, fallback to normal Qwen session. expected=%s", probePath) +session, err = mllm.NewSession(modelPath) +} +} else { +log.Printf("[GoMobile] Probing disabled. Using normal Qwen session.") +session, err = mllm.NewSession(modelPath) +} +if err != nil { +return "Error: Qwen load failed: " + err.Error() +} +sessionID := filepath.Base(modelPath) +session.Insert(sessionID) +service.RegisterSession(sessionID, session) +} + +if ocrPath != "" { +log.Printf("[GoMobile] Loading OCR: %s", ocrPath) +session, err := mllm.NewDeepseekOCRSession(ocrPath) +if err != nil { +return "Error: OCR load failed: " + err.Error() +} +sessionID := filepath.Base(ocrPath) +session.Insert(sessionID) +service.RegisterSession(sessionID, session) +} + +if !mllm.StartService(1) { +return "Error: StartService failed" +} + +go func() { +s := server.NewServer("127.0.0.1:8080", service) +log.Println("[GoMobile] HTTP Server listening on 8080") +s.Start() +}() + +return "Success: Server Running on 127.0.0.1:8080" +} diff --git a/mllm/c_api/Runtime.cpp b/mllm/c_api/Runtime.cpp index e0b301ced..92809afc0 100644 --- a/mllm/c_api/Runtime.cpp +++ b/mllm/c_api/Runtime.cpp @@ -4,15 +4,16 @@ #include "mllm/c_api/Runtime.h" #include "mllm/engine/service/Service.hpp" #include "mllm/models/qwen3/modeling_qwen3_service.hpp" +#include "mllm/models/qwen3/modeling_qwen3_probing_service.hpp" #include "mllm/models/deepseek_ocr/modeling_deepseek_ocr_service.hpp" #include #include #include #include -#include // for strncpy +#include // for strncpy struct MllmSessionWrapper { - std::shared_ptr session_ptr; + std::shared_ptr session_ptr; }; //===----------------------------------------------------------------------===// @@ -34,14 +35,11 @@ MllmCAny memoryReport() { } int32_t isOk(MllmCAny ret) { - if (ret.type_id == kRetCode && ret.v_return_code == 0) - return true; - if (ret.type_id == kCustomObject && ret.v_custom_ptr != nullptr) - return true; + if (ret.type_id == kRetCode && ret.v_return_code == 0) return true; + if (ret.type_id == kCustomObject && ret.v_custom_ptr != nullptr) return true; return false; } - //===----------------------------------------------------------------------===// // Mllm wrapper functions //===----------------------------------------------------------------------===// @@ -64,132 +62,134 @@ MllmCAny convert2Float(double v) { return MllmCAny{.type_id = kFloat, .v_fp64 = //===----------------------------------------------------------------------===// MllmCAny startService(size_t worker_threads) { - mllm::service::startService(worker_threads); - return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; + mllm::service::startService(worker_threads); + return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; } MllmCAny stopService() { - mllm::service::stopService(); - return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; + mllm::service::stopService(); + return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; } -void setLogLevel(int level) { - mllm::setLogLevel(static_cast(level)); -} +void setLogLevel(int level) { mllm::setLogLevel(static_cast(level)); } MllmCAny createQwen3Session(const char* model_path) { - if (model_path == nullptr) { - printf("[C++ Service] createQwen3Session error: invalid arguments.\n"); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } - try { - auto qwen3_session = std::make_shared(); - qwen3_session->fromPreTrain(model_path); - - auto* handle = new MllmSessionWrapper(); - handle->session_ptr = qwen3_session; - - return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; - } catch (const std::exception& e) { - printf("[C++ Service] createQwen3Session exception: %s\n", e.what()); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } + if (model_path == nullptr) { + printf("[C++ Service] createQwen3Session error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + try { + auto qwen3_session = std::make_shared(); + qwen3_session->fromPreTrain(model_path); + + auto* handle = new MllmSessionWrapper(); + handle->session_ptr = qwen3_session; + + return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; + } catch (const std::exception& e) { + printf("[C++ Service] createQwen3Session exception: %s\n", e.what()); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } +} + +MllmCAny createQwen3ProbingSession(const char* model_path, const char* probe_path) { + if (model_path == nullptr || probe_path == nullptr) { + printf("[C++ Service] createQwen3ProbingSession error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + try { + auto qwen3_session = std::make_shared(); + qwen3_session->fromPreTrain(model_path); + + mllm::models::qwen3_probing::ProbingArgs p_args; + p_args.enable_prefill_check = true; + p_args.prefill_stop_threshold = 0.7f; + p_args.default_prefill_layers = {27, 30}; + p_args.enable_decode_check = true; + p_args.decode_stop_threshold = 0.8f; + p_args.pos_threshold = 0.9f; + + qwen3_session->setProbingArgs(p_args); + qwen3_session->loadProbes(probe_path, p_args); + + auto* handle = new MllmSessionWrapper(); + handle->session_ptr = qwen3_session; + + return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; + } catch (const std::exception& e) { + printf("[C++ Service] createQwen3ProbingSession exception: %s\n", e.what()); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } } MllmCAny createDeepseekOCRSession(const char* model_path) { - if (model_path == nullptr) { - printf("[C++ Service] createDeepseekOCRSession error: invalid arguments.\n"); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } - try { - auto dpsk_session = std::make_shared(); - dpsk_session->fromPreTrain(model_path); - - auto* handle = new MllmSessionWrapper(); - handle->session_ptr = dpsk_session; - - return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; - } catch (const std::exception& e) { - printf("[C++ Service] createDeepseekOCRSession exception: %s\n", e.what()); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } + if (model_path == nullptr) { + printf("[C++ Service] createDeepseekOCRSession error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + try { + auto dpsk_session = std::make_shared(); + dpsk_session->fromPreTrain(model_path); + + auto* handle = new MllmSessionWrapper(); + handle->session_ptr = dpsk_session; + + return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; + } catch (const std::exception& e) { + printf("[C++ Service] createDeepseekOCRSession exception: %s\n", e.what()); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } } MllmCAny insertSession(const char* session_id, MllmCAny handle) { - if (session_id == nullptr || handle.type_id != kCustomObject || handle.v_custom_ptr == nullptr) { - printf("[C++ Service] insertSession error: invalid arguments.\n"); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } - - auto* session_wrapper = reinterpret_cast(handle.v_custom_ptr); - mllm::service::insertSession(std::string(session_id), session_wrapper->session_ptr); - return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; + if (session_id == nullptr || handle.type_id != kCustomObject || handle.v_custom_ptr == nullptr) { + printf("[C++ Service] insertSession error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + + auto* session_wrapper = reinterpret_cast(handle.v_custom_ptr); + mllm::service::insertSession(std::string(session_id), session_wrapper->session_ptr); + return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; } MllmCAny freeSession(MllmCAny handle) { - if (handle.type_id != kCustomObject || handle.v_custom_ptr == nullptr) { - printf("[C++ Service] freeSession error: invalid arguments.\n"); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } - - auto* session_wrapper = reinterpret_cast(handle.v_custom_ptr); - delete session_wrapper; - return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; -} + if (handle.type_id != kCustomObject || handle.v_custom_ptr == nullptr) { + printf("[C++ Service] freeSession error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + auto* session_wrapper = reinterpret_cast(handle.v_custom_ptr); + delete session_wrapper; + return MllmCAny{.type_id = kRetCode, .v_return_code = 0}; +} MllmCAny sendRequest(const char* session_id, const char* json_request) { - if (session_id == nullptr || json_request == nullptr) { - printf("[C++ Service] sendRequest error: invalid arguments.\n"); - return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; - } - int status = mllm::service::sendRequest(std::string(json_request)); - return MllmCAny{.type_id = kRetCode, .v_return_code = status}; + if (session_id == nullptr || json_request == nullptr) { + printf("[C++ Service] sendRequest error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + int status = mllm::service::sendRequest(std::string(json_request)); + return MllmCAny{.type_id = kRetCode, .v_return_code = status}; } const char* pollResponse(const char* session_id) { - if (session_id == nullptr) { - return nullptr; - } - - std::string request_id = std::string(session_id); - mllm::service::Response response = mllm::service::getResponse(request_id); - - if (response.empty()) { - return nullptr; - } - - bool finished = false; - try { - nlohmann::json j = nlohmann::json::parse(response); - - - if (j.contains("choices")) { - if (j["choices"].is_array() && !j["choices"].empty()) { - const auto& first_choice = j["choices"][0]; - if (first_choice.contains("finish_reason") && first_choice["finish_reason"] == "stop") { - finished = true; - } - } - } - - } catch (const nlohmann::json::parse_error& e) { - printf("[C++ Service] pollResponse JSON parse error: %s\n", e.what()); - return nullptr; - } - - if (finished) { - return nullptr; - } - - char* c_response = new char[response.length() + 1]; - strncpy(c_response, response.c_str(), response.length() + 1); - - return c_response; + if (session_id == nullptr) { return nullptr; } + + std::string request_id = std::string(session_id); + mllm::service::Response response = mllm::service::getResponse(request_id); + + if (response.empty()) { return nullptr; } + + // Always return the chunk to upper layers, including the final chunk with + // finish_reason="stop". The HTTP layer already handles stop detection and + // stream termination; filtering it here can drop valid assistant output. + + char* c_response = new char[response.length() + 1]; + strncpy(c_response, response.c_str(), response.length() + 1); + + return c_response; } void freeResponseString(const char* response_str) { - if (response_str != nullptr) { - delete[] response_str; - } + if (response_str != nullptr) { delete[] response_str; } } \ No newline at end of file diff --git a/mllm/c_api/Runtime.h b/mllm/c_api/Runtime.h index 97b34f4e8..dfff60278 100644 --- a/mllm/c_api/Runtime.h +++ b/mllm/c_api/Runtime.h @@ -42,6 +42,8 @@ void setLogLevel(int level); MllmCAny createQwen3Session(const char* model_path); +MllmCAny createQwen3ProbingSession(const char* model_path, const char* probe_path); + MllmCAny createDeepseekOCRSession(const char* model_path); MllmCAny insertSession(const char* session_id, MllmCAny handle); diff --git a/mllm/engine/service/Service.cpp b/mllm/engine/service/Service.cpp index 5400bcdb2..5f565d693 100644 --- a/mllm/engine/service/Service.cpp +++ b/mllm/engine/service/Service.cpp @@ -47,7 +47,7 @@ void ResponsePool::push(const std::string& req_id, ResponseItem item) { std::unique_lock lk(mtx_); auto& q = queues_[req_id]; q.push(std::move(item)); - cv_.notify_one(); + cv_.notify_all(); } std::optional ResponsePool::pop(const std::string& req_id) { diff --git a/mllm/models/qwen3/modeling_qwen3_probing_service.hpp b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp index f4b3ceaf7..cd62686bd 100644 --- a/mllm/models/qwen3/modeling_qwen3_probing_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp @@ -159,9 +159,7 @@ class ProbeClassifier : public Module { auto logits = linear_(x); float val = 0.0f; - if (logits.dtype() == mllm::kFloat32) { - val = logits.ptr()[0]; - } + if (logits.dtype() == mllm::kFloat32) { val = logits.ptr()[0]; } return 1.0f / (1.0f + std::exp(-val)); } @@ -590,6 +588,30 @@ class Qwen3ProbingSession final : public ::mllm::service::Session { mllm::cpu::wakeupHpcThreadPool(); auto messages = request["messages"]; + // Keep probing-chat behavior aligned with main_probing: + // if a previous turn ended with hallucination early-exit, drop that + // assistant payload and its paired user turn from incoming history. + if (messages.is_array()) { + nlohmann::json sanitized_messages = nlohmann::json::array(); + for (const auto& msg : messages) { + if (!msg.is_object()) continue; + + const auto role = msg.value("role", ""); + std::string content; + if (msg.contains("content") && msg["content"].is_string()) { content = msg["content"].get(); } + + if (role == "assistant" && content.find("early_exit_hallucination") != std::string::npos) { + if (!sanitized_messages.empty() && sanitized_messages.back().value("role", "") == "user") { + sanitized_messages.erase(sanitized_messages.end() - 1); + } + continue; + } + + sanitized_messages.push_back(msg); + } + messages = std::move(sanitized_messages); + } + // 简短指令 std::string concise_instruction = " Please answer in a single, complete sentence. Keep it concise.";