diff --git a/go/plugins/ollama/model.go b/go/plugins/ollama/model.go
new file mode 100644
index 0000000000..ee88a747e0
--- /dev/null
+++ b/go/plugins/ollama/model.go
@@ -0,0 +1,160 @@
+package ollama
+
+import (
+	"errors"
+
+	"github.com/firebase/genkit/go/ai"
+)
+
+var topLevelOpts = map[string]struct{}{
+	"think":      {},
+	"keep_alive": {},
+}
+
+// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
+// That's why we have multiple request interfaces for the Ollama API below.
+
+/*
+TODO: Support optional, advanced parameters:
+format: the format to return a response in. Currently the only accepted value is json
+options: additional model parameters listed in the documentation for the Modelfile, such as temperature
+system: system message (overrides what is defined in the Modelfile)
+template: the prompt template to use (overrides what is defined in the Modelfile)
+context: the context parameter returned from a previous request to /generate; this can be used to keep a short conversational memory
+stream: if false, the response will be returned as a single response object rather than a stream of objects
+raw: if true, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
+*/
+type ollamaChatRequest struct {
+	Messages  []*ollamaMessage `json:"messages"`
+	Images    []string         `json:"images,omitempty"`
+	Model     string           `json:"model"`
+	Stream    bool             `json:"stream"`
+	Format    string           `json:"format,omitempty"`
+	Tools     []ollamaTool     `json:"tools,omitempty"`
+	Think     any              `json:"think,omitempty"`
+	Options   map[string]any   `json:"options,omitempty"`
+	KeepAlive string           `json:"keep_alive,omitempty"`
+}
+
+func (o *ollamaChatRequest) ApplyOptions(cfg any) error {
+	if cfg == nil {
+		return nil
+	}
+
+	switch cfg := cfg.(type) {
+	case GenerateContentConfig:
+		o.applyGenerateContentConfig(&cfg)
+		return nil
+	case *GenerateContentConfig:
+		o.applyGenerateContentConfig(cfg)
+		return nil
+	case map[string]any:
+		return o.applyMapAny(cfg)
+	case *ai.GenerationCommonConfig:
+		return o.applyGenerationCommonConfig(cfg)
+	case ai.GenerationCommonConfig:
+		return o.applyGenerationCommonConfig(&cfg)
+	default:
+		return errors.New("unknown generation config")
+	}
+}
+func (o *ollamaChatRequest) applyGenerateContentConfig(cfg *GenerateContentConfig) {
+	if cfg == nil {
+		return
+	}
+
+	// thinking
+	if cfg.Think != nil {
+		o.Think = cfg.Think
+	}
+
+	// runtime options
+	opts := map[string]any{}
+
+	if cfg.Seed != nil {
+		opts["seed"] = *cfg.Seed
+	}
+	if cfg.Temperature != nil {
+		opts["temperature"] = *cfg.Temperature
+	}
+	if cfg.TopK != nil {
+		opts["top_k"] = *cfg.TopK
+	}
+	if cfg.TopP != nil {
+		opts["top_p"] = *cfg.TopP
+	}
+	if cfg.MinP != nil {
+		opts["min_p"] = *cfg.MinP
+	}
+	if len(cfg.Stop) > 0 {
+		opts["stop"] = cfg.Stop
+	}
+	if cfg.NumCtx != nil {
+		opts["num_ctx"] = *cfg.NumCtx
+	}
+	if cfg.NumPredict != nil {
+		opts["num_predict"] = *cfg.NumPredict
+	}
+
+	if len(opts) > 0 {
+		o.Options = opts
+	}
+}
+func (o *ollamaChatRequest) applyGenerationCommonConfig(cfg *ai.GenerationCommonConfig) error {
+	if cfg == nil {
+		return nil
+	}
+
+	opts := map[string]any{}
+
+	if cfg.MaxOutputTokens > 0 {
+		opts["num_predict"] = cfg.MaxOutputTokens
+	}
+	if len(cfg.StopSequences) > 0 {
+		opts["stop"] = cfg.StopSequences
+	}
+	if cfg.Temperature != 0 {
+		opts["temperature"] = cfg.Temperature
+	}
+	if cfg.TopK > 0 {
+		opts["top_k"] = cfg.TopK
+	}
+	if cfg.TopP > 0 {
+		opts["top_p"] = cfg.TopP
+	}
+
+	if len(opts) > 0 {
+		o.Options = opts
+	}
+
+	return nil
+}
+
+func (o *ollamaChatRequest) applyMapAny(m map[string]any) error {
+	if len(m) == 0 {
+		return nil
+	}
+	opts := map[string]any{}
+	for k, v := range m {
+		if _, isTopLevel := topLevelOpts[k]; isTopLevel {
+			switch k {
+			case "think":
+				o.Think = v
+			case "keep_alive":
+				if s, ok := v.(string); ok {
+					o.KeepAlive = s
+				} else {
+					return errors.New("keep_alive must be a string")
+				}
+			}
+			continue
+		}
+		opts[k] = v
+	}
+
+	if len(opts) > 0 {
+		o.Options = opts
+	}
+
+	return nil
+}
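
Reviewer note: a minimal sketch (not part of the change) of how `ApplyOptions` routes a `map[string]any` config — keys listed in `topLevelOpts` ("think", "keep_alive") are promoted to fields on the request, everything else is passed through as an Ollama runtime option. The example name, file placement, and model name are illustrative only.

```go
// Sketch only: would live in an example_test.go inside package ollama.
package ollama

import "fmt"

func Example_applyOptionsRouting() {
	req := &ollamaChatRequest{Model: "llama3.1"} // placeholder model name

	err := req.ApplyOptions(map[string]any{
		"think":       true,  // promoted to req.Think
		"keep_alive":  "10m", // promoted to req.KeepAlive
		"temperature": 0.2,   // passed through to req.Options
		"num_ctx":     8192,  // passed through to req.Options
	})
	if err != nil {
		fmt.Println("apply options:", err)
		return
	}

	fmt.Println(req.Think, req.KeepAlive, req.Options["temperature"], req.Options["num_ctx"])
	// Output: true 10m 0.2 8192
}
```
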
opts["top_p"] = cfg.TopP + } + + if len(opts) > 0 { + o.Options = opts + } + + return nil +} + +func (o *ollamaChatRequest) applyMapAny(m map[string]any) error { + if len(m) == 0 { + return nil + } + opts := map[string]any{} + for k, v := range m { + if _, isTopLevel := topLevelOpts[k]; isTopLevel { + switch k { + case "think": + o.Think = v + case "keep_alive": + if s, ok := v.(string); ok { + o.KeepAlive = s + } else { + return errors.New("keep_alive must be string") + } + } + continue + } + opts[k] = v + } + + if len(opts) > 0 { + o.Options = opts + } + + return nil +} diff --git a/go/plugins/ollama/model_test.go b/go/plugins/ollama/model_test.go new file mode 100644 index 0000000000..a18df09770 --- /dev/null +++ b/go/plugins/ollama/model_test.go @@ -0,0 +1,129 @@ +package ollama + +import ( + "reflect" + "testing" + + "github.com/firebase/genkit/go/ai" +) + +func TestOllamaChatRequest_ApplyOptions(t *testing.T) { + seed := 42 + temp := 0.7 + + tests := []struct { + name string + cfg any + want *ollamaChatRequest + wantErr bool + }{ + { + name: "GenerateContentConfig pointer", + cfg: &GenerateContentConfig{ + Seed: &seed, + Temperature: &temp, + Think: true, + }, + want: &ollamaChatRequest{ + Think: true, + Options: map[string]any{ + "seed": seed, + "temperature": temp, + }, + }, + }, + { + name: "GenerateContentConfig value", + cfg: GenerateContentConfig{ + Seed: &seed, + Think: true, + }, + want: &ollamaChatRequest{ + Think: true, + Options: map[string]any{ + "seed": seed, + }, + }, + }, + { + name: "map[string]any with opts only", + cfg: map[string]any{ + "temperature": 0.5, + "top_k": 40, + }, + want: &ollamaChatRequest{ + Options: map[string]any{ + "temperature": 0.5, + "top_k": 40, + }, + }, + }, + { + name: "map[string]any with top level fields", + cfg: map[string]any{ + "think": true, + "keep_alive": "10m", + }, + want: &ollamaChatRequest{ + Think: true, + KeepAlive: "10m", + }, + }, + { + name: "map[string]any mixed main and opts", + cfg: map[string]any{ + "temperature": 0.9, + "think": true, + }, + want: &ollamaChatRequest{ + Think: true, + Options: map[string]any{ + "temperature": 0.9, + }, + }, + }, + { + name: "GenerationCommonConfig pointer", + cfg: &ai.GenerationCommonConfig{ + Temperature: temp, + }, + want: &ollamaChatRequest{ + Options: map[string]any{ + "temperature": temp, + }, + }, + }, + { + name: "nil config", + cfg: nil, + want: &ollamaChatRequest{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &ollamaChatRequest{} + + err := req.ApplyOptions(tt.cfg) + + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(req, tt.want) { + t.Errorf( + "unexpected result:\nwant: %#v\n got: %#v", + tt.want, + req, + ) + } + }) + } +} diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go index 4eb4469673..163ade9ed1 100644 --- a/go/plugins/ollama/ollama.go +++ b/go/plugins/ollama/ollama.go @@ -119,29 +119,27 @@ type ollamaMessage struct { Content string `json:"content,omitempty"` Images []string `json:"images,omitempty"` ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"` + Thinking string `json:"thinking,omitempty"` } -// Ollama has two API endpoints, one with a chat interface and another with a generate response interface. -// That's why have multiple request interfaces for the Ollama API below. - -/* -TODO: Support optional, advanced parameters: -format: the format to return a response in. 
diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go
index 4eb4469673..163ade9ed1 100644
--- a/go/plugins/ollama/ollama.go
+++ b/go/plugins/ollama/ollama.go
@@ -119,29 +119,27 @@ type ollamaMessage struct {
 	Content   string           `json:"content,omitempty"`
 	Images    []string         `json:"images,omitempty"`
 	ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
+	Thinking  string           `json:"thinking,omitempty"`
 }
 
-// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
-// That's why have multiple request interfaces for the Ollama API below.
-
-/*
-TODO: Support optional, advanced parameters:
-format: the format to return a response in. Currently the only accepted value is json
-options: additional model parameters listed in the documentation for the Modelfile such as temperature
-system: system message to (overrides what is defined in the Modelfile)
-template: the prompt template to use (overrides what is defined in the Modelfile)
-context: the context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory
-stream: if false the response will be returned as a single response object, rather than a stream of objects
-raw: if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
-keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
-*/
-type ollamaChatRequest struct {
-	Messages []*ollamaMessage `json:"messages"`
-	Images   []string         `json:"images,omitempty"`
-	Model    string           `json:"model"`
-	Stream   bool             `json:"stream"`
-	Format   string           `json:"format,omitempty"`
-	Tools    []ollamaTool     `json:"tools,omitempty"`
+type GenerateContentConfig struct {
+	// Thinking mode:
+	//   ollama:  true | false
+	//   gpt-oss: "low" | "medium" | "high"
+	Think any
+
+	// Runtime options
+	Seed        *int
+	Temperature *float64
+	TopK        *int
+	TopP        *float64
+	MinP        *float64
+	Stop        []string
+	NumCtx      *int
+	NumPredict  *int
+
+	// Ollama-specific
+	KeepAlive string
 }
 
 type ollamaModelRequest struct {
@@ -184,6 +182,7 @@ type ollamaChatResponse struct {
 	Message   struct {
 		Role      string           `json:"role"`
 		Content   string           `json:"content"`
+		Thinking  string           `json:"thinking"`
 		ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
 	} `json:"message"`
 }
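
Reviewer note: because `Think` is typed `any`, the same field carries both forms Ollama accepts — a boolean for standard thinking models and a reasoning-effort string for gpt-oss. A small sketch (not part of the diff; example name is illustrative) of how both values flow through `ApplyOptions` into the request's `think` field:

```go
// Sketch only: both Think forms pass through ApplyOptions unchanged.
package ollama

import "fmt"

func Example_thinkVariants() {
	temp := 0.1

	// Standard Ollama thinking models toggle thinking with a boolean.
	boolReq := &ollamaChatRequest{}
	_ = boolReq.ApplyOptions(&GenerateContentConfig{Think: true, Temperature: &temp}) // never errors for this config type

	// gpt-oss models take a reasoning-effort level instead.
	levelReq := &ollamaChatRequest{}
	_ = levelReq.ApplyOptions(&GenerateContentConfig{Think: "high"})

	fmt.Println(boolReq.Think, levelReq.Think)
	// Output: true high
}
```
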
@@ -253,6 +252,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 			Images: images,
 			Stream: stream,
 		}
+
 	} else {
 		var messages []*ollamaMessage
 		// Translate all messages to ollama message format.
@@ -263,12 +263,17 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 			}
 			messages = append(messages, message)
 		}
+
 		chatReq := ollamaChatRequest{
 			Messages: messages,
 			Model:    g.model.Name,
 			Stream:   stream,
 			Images:   images,
 		}
+		if err := chatReq.ApplyOptions(input.Config); err != nil {
+			return nil, fmt.Errorf("failed to apply options: %v", err)
+		}
+
 		if len(input.Tools) > 0 {
 			tools, err := convertTools(input.Tools)
 			if err != nil {
@@ -417,6 +422,8 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 				return nil, fmt.Errorf("failed to marshal tool response: %v", err)
 			}
 			contentBuilder.WriteString(string(outputJSON))
+		} else if part.IsReasoning() {
+			contentBuilder.WriteString(part.Text)
 		} else {
 			return nil, errors.New("unsupported content type")
 		}
@@ -439,6 +446,7 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
 	if err := json.Unmarshal(responseData, &response); err != nil {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
+
 	modelResponse := &ai.ModelResponse{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
@@ -458,6 +466,10 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
 		aiPart := ai.NewTextPart(response.Message.Content)
 		modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
 	}
+	if response.Message.Thinking != "" {
+		aiPart := ai.NewReasoningPart(response.Message.Thinking, nil)
+		modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	}
 
 	return modelResponse, nil
 }
@@ -504,6 +516,11 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
 		chunk.Content = append(chunk.Content, aiPart)
 	}
 
+	if response.Message.Thinking != "" {
+		aiPart := ai.NewReasoningPart(response.Message.Thinking, nil)
+		chunk.Content = append(chunk.Content, aiPart)
+	}
+
 	return chunk, nil
 }
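
Finally, a sketch (same package, not part of the diff; test name and JSON payload are illustrative) of how the new `thinking` handling in `translateChatResponse` could be exercised: a response carrying both `content` and `thinking` should come back as a text part plus a reasoning part.

```go
// Sketch only: checks that a "thinking" field surfaces as a reasoning part.
package ollama

import "testing"

func TestTranslateChatResponse_Thinking(t *testing.T) {
	raw := []byte(`{"message": {"role": "assistant", "content": "4", "thinking": "2+2 is 4"}}`)

	resp, err := translateChatResponse(raw)
	if err != nil {
		t.Fatalf("translateChatResponse: %v", err)
	}

	var gotText, gotReasoning bool
	for _, part := range resp.Message.Content {
		switch {
		case part.IsReasoning():
			gotReasoning = part.Text == "2+2 is 4"
		default:
			gotText = part.Text == "4"
		}
	}
	if !gotText || !gotReasoning {
		t.Errorf("expected a text part and a reasoning part, got %+v", resp.Message.Content)
	}
}
```
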