feat: add NATS EventBus implementation (#2385)
Co-authored-by: pozen <pozen@users.noreply.github.com>
This commit is contained in:
@ -285,6 +285,7 @@ require (
|
||||
require (
|
||||
github.com/apache/pulsar-client-go v0.16.0
|
||||
github.com/eino-contrib/ollama v0.1.0
|
||||
github.com/nats-io/nats.go v1.34.1
|
||||
)
|
||||
|
||||
require (
|
||||
@ -312,6 +313,8 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/mtibben/percent v0.2.1 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nats-io/nkeys v0.4.7 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
||||
|
||||
@ -870,9 +870,14 @@ github.com/nats-io/jwt/v2 v2.0.3/go.mod h1:VRP+deawSXyhNjXmxPCHskrR6Mq50BqpEI5SE
|
||||
github.com/nats-io/nats-server/v2 v2.5.0/go.mod h1:Kj86UtrXAL6LwYRA6H4RqzkHhK0Vcv2ZnKD5WbQ1t3g=
|
||||
github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w=
|
||||
github.com/nats-io/nats.go v1.12.1/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
|
||||
github.com/nats-io/nats.go v1.34.1 h1:syWey5xaNHZgicYBemv0nohUPPmaLteiBEUT6Q5+F/4=
|
||||
github.com/nats-io/nats.go v1.34.1/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
|
||||
github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
||||
github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s=
|
||||
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
|
||||
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
|
||||
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus/impl/kafka"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus/impl/nats"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus/impl/nsq"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus/impl/pulsar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus/impl/rmq"
|
||||
@ -57,9 +58,11 @@ func (consumerServiceImpl) RegisterConsumer(nameServer, topic, group string, con
|
||||
return rmq.RegisterConsumer(nameServer, topic, group, consumerHandler, opts...)
|
||||
case "pulsar":
|
||||
return pulsar.RegisterConsumer(nameServer, topic, group, consumerHandler, opts...)
|
||||
case "nats":
|
||||
return nats.RegisterConsumer(nameServer, topic, group, consumerHandler, opts...)
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar", tp)
|
||||
return fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar, nats", tp)
|
||||
}
|
||||
|
||||
func NewProducer(nameServer, topic, group string, retries int) (eventbus.Producer, error) {
|
||||
@ -73,9 +76,11 @@ func NewProducer(nameServer, topic, group string, retries int) (eventbus.Produce
|
||||
return rmq.NewProducer(nameServer, topic, group, retries)
|
||||
case "pulsar":
|
||||
return pulsar.NewProducer(nameServer, topic, group)
|
||||
case "nats":
|
||||
return nats.NewProducer(nameServer, topic, group)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar", tp)
|
||||
return nil, fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar, nats", tp)
|
||||
}
|
||||
|
||||
func InitResourceEventBusProducer() (eventbus.Producer, error) {
|
||||
|
||||
283
backend/infra/eventbus/impl/nats/consumer.go
Normal file
283
backend/infra/eventbus/impl/nats/consumer.go
Normal file
@ -0,0 +1,283 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/signal"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func RegisterConsumer(serverURL, topic, group string, consumerHandler eventbus.ConsumerHandler, opts ...eventbus.ConsumerOpt) error {
|
||||
// Validate input parameters
|
||||
if serverURL == "" {
|
||||
return fmt.Errorf("NATS server URL is empty")
|
||||
}
|
||||
if topic == "" {
|
||||
return fmt.Errorf("topic is empty")
|
||||
}
|
||||
if group == "" {
|
||||
return fmt.Errorf("group is empty")
|
||||
}
|
||||
if consumerHandler == nil {
|
||||
return fmt.Errorf("consumer handler is nil")
|
||||
}
|
||||
|
||||
// Parse consumer options
|
||||
option := &eventbus.ConsumerOption{}
|
||||
for _, opt := range opts {
|
||||
opt(option)
|
||||
}
|
||||
|
||||
// Prepare connection options
|
||||
natsOptions := []nats.Option{
|
||||
nats.Name(fmt.Sprintf("%s-consumer", group)),
|
||||
nats.ReconnectWait(2 * time.Second),
|
||||
nats.MaxReconnects(-1), // Unlimited reconnects
|
||||
nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
|
||||
logs.Warnf("NATS consumer disconnected: %v", err)
|
||||
}),
|
||||
nats.ReconnectHandler(func(nc *nats.Conn) {
|
||||
logs.Infof("NATS consumer reconnected to %s", nc.ConnectedUrl())
|
||||
}),
|
||||
}
|
||||
|
||||
// Add authentication support
|
||||
if err := addAuthentication(&natsOptions); err != nil {
|
||||
return fmt.Errorf("setup authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Create NATS connection
|
||||
nc, err := nats.Connect(serverURL, natsOptions...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create NATS connection failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if JetStream is enabled
|
||||
useJetStream := os.Getenv(consts.NATSUseJetStream) == "true"
|
||||
|
||||
// Create cancellable context for better resource management
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
if useJetStream {
|
||||
// Use JetStream for persistent messaging
|
||||
err = startJetStreamConsumer(ctx, nc, topic, group, consumerHandler)
|
||||
} else {
|
||||
// Use core NATS for simple pub/sub
|
||||
err = startCoreConsumer(ctx, nc, topic, group, consumerHandler)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
cancel() // Cancel context to prevent leak
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle graceful shutdown
|
||||
safego.Go(context.Background(), func() {
|
||||
signal.WaitExit()
|
||||
logs.Infof("shutting down NATS consumer for topic: %s, group: %s", topic, group)
|
||||
cancel() // Cancel the context to stop consumer loop
|
||||
nc.Close()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startJetStreamConsumer starts a JetStream-based consumer for persistent messaging
|
||||
func startJetStreamConsumer(ctx context.Context, nc *nats.Conn, topic, group string, consumerHandler eventbus.ConsumerHandler) error {
|
||||
// Create JetStream context
|
||||
js, err := nc.JetStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create JetStream context failed: %w", err)
|
||||
}
|
||||
|
||||
// Ensure Stream exists
|
||||
if err := ensureStream(js, topic); err != nil {
|
||||
return fmt.Errorf("ensure stream failed: %w", err)
|
||||
}
|
||||
|
||||
// Start consuming messages in a goroutine
|
||||
safego.Go(ctx, func() {
|
||||
defer nc.Close()
|
||||
|
||||
// Create durable pull subscription
|
||||
sub, err := js.PullSubscribe(topic, group)
|
||||
if err != nil {
|
||||
logs.Errorf("create NATS JetStream subscription failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logs.Infof("NATS JetStream consumer stopped for topic: %s, group: %s", topic, group)
|
||||
return
|
||||
default:
|
||||
// Fetch one message at a time for better control and resource management
|
||||
msgs, err := sub.Fetch(1, nats.MaxWait(1*time.Second))
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
// Handle timeout and other non-fatal errors
|
||||
if err == nats.ErrTimeout {
|
||||
continue
|
||||
}
|
||||
logs.Errorf("fetch NATS JetStream message error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the single message
|
||||
if len(msgs) > 0 {
|
||||
msg := msgs[0]
|
||||
eventMsg := &eventbus.Message{
|
||||
Topic: topic,
|
||||
Group: group,
|
||||
Body: msg.Data,
|
||||
}
|
||||
|
||||
// Handle message with context
|
||||
if err := consumerHandler.HandleMessage(ctx, eventMsg); err != nil {
|
||||
logs.Errorf("handle NATS JetStream message failed, topic: %s, group: %s, err: %v", topic, group, err)
|
||||
// Negative acknowledge on error
|
||||
msg.Nak()
|
||||
continue
|
||||
}
|
||||
|
||||
// Acknowledge message on success
|
||||
msg.Ack()
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startCoreConsumer starts a core NATS consumer for simple pub/sub
|
||||
func startCoreConsumer(ctx context.Context, nc *nats.Conn, topic, group string, consumerHandler eventbus.ConsumerHandler) error {
|
||||
// Start consuming messages in a goroutine
|
||||
safego.Go(ctx, func() {
|
||||
defer nc.Close()
|
||||
|
||||
// Create queue subscription for load balancing
|
||||
sub, err := nc.QueueSubscribe(topic, group, func(msg *nats.Msg) {
|
||||
eventMsg := &eventbus.Message{
|
||||
Topic: topic,
|
||||
Group: group,
|
||||
Body: msg.Data,
|
||||
}
|
||||
|
||||
// Handle message with context
|
||||
if err := consumerHandler.HandleMessage(ctx, eventMsg); err != nil {
|
||||
logs.Errorf("handle NATS core message failed, topic: %s, group: %s, err: %v", topic, group, err)
|
||||
// For core NATS, we can't nack, just log the error
|
||||
return
|
||||
}
|
||||
|
||||
logs.Debugf("successfully processed NATS core message, topic: %s, group: %s", topic, group)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logs.Errorf("create NATS core subscription failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Wait for context cancellation
|
||||
<-ctx.Done()
|
||||
logs.Infof("NATS core consumer stopped for topic: %s, group: %s", topic, group)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addAuthentication adds authentication options to NATS connection
|
||||
func addAuthentication(options *[]nats.Option) error {
|
||||
// JWT authentication with NKey
|
||||
if jwtToken := os.Getenv(consts.NATSJWTToken); jwtToken != "" {
|
||||
nkeySeed := os.Getenv(consts.NATSNKeySeed)
|
||||
if nkeySeed == "" {
|
||||
return fmt.Errorf("NATS_NKEY_SEED is required when using JWT authentication")
|
||||
}
|
||||
*options = append(*options, nats.UserJWTAndSeed(jwtToken, nkeySeed))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Username/password authentication
|
||||
if username := os.Getenv(consts.NATSUsername); username != "" {
|
||||
password := os.Getenv(consts.NATSPassword)
|
||||
*options = append(*options, nats.UserInfo(username, password))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Token authentication
|
||||
if token := os.Getenv(consts.NATSToken); token != "" {
|
||||
*options = append(*options, nats.Token(token))
|
||||
return nil
|
||||
}
|
||||
|
||||
// No authentication configured
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureStream ensures that a JetStream stream exists for the given subject
|
||||
func ensureStream(js nats.JetStreamContext, subject string) error {
|
||||
// Replace dots and other invalid characters with underscores for stream name
|
||||
// NATS stream names cannot contain dots, spaces, or other special characters
|
||||
streamName := strings.ReplaceAll(subject, ".", "_") + "_STREAM"
|
||||
|
||||
// Check if Stream already exists
|
||||
_, err := js.StreamInfo(streamName)
|
||||
if err == nil {
|
||||
return nil // Stream already exists
|
||||
}
|
||||
|
||||
// Only create stream if it's specifically a "stream not found" error
|
||||
if err != nats.ErrStreamNotFound {
|
||||
return fmt.Errorf("failed to check stream %s: %w", streamName, err)
|
||||
}
|
||||
|
||||
// Create Stream if it doesn't exist
|
||||
_, err = js.AddStream(&nats.StreamConfig{
|
||||
Name: streamName,
|
||||
Subjects: []string{subject},
|
||||
Storage: nats.FileStorage, // File storage for persistence
|
||||
MaxAge: 24 * time.Hour, // Retain messages for 24 hours
|
||||
MaxMsgs: 1000000, // Maximum number of messages
|
||||
MaxBytes: 1024 * 1024 * 1024, // Maximum storage size (1GB)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stream %s: %w", streamName, err)
|
||||
}
|
||||
|
||||
logs.Infof("created NATS JetStream stream: %s", streamName)
|
||||
return nil
|
||||
}
|
||||
275
backend/infra/eventbus/impl/nats/nats_test.go
Normal file
275
backend/infra/eventbus/impl/nats/nats_test.go
Normal file
@ -0,0 +1,275 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
var serviceURL = "nats://localhost:4222"
|
||||
|
||||
func TestNATSProducer(t *testing.T) {
|
||||
if os.Getenv("NATS_LOCAL_TEST") != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
// Set up NATS connection options
|
||||
opts := []nats.Option{nats.Name("test-producer")}
|
||||
|
||||
// Add authentication if provided
|
||||
if jwtToken := os.Getenv(consts.NATSJWTToken); jwtToken != "" {
|
||||
opts = append(opts, nats.UserJWT(func() (string, error) {
|
||||
return jwtToken, nil
|
||||
}, func(nonce []byte) ([]byte, error) {
|
||||
return []byte(os.Getenv(consts.NATSNKeySeed)), nil
|
||||
}))
|
||||
} else if username := os.Getenv(consts.NATSUsername); username != "" {
|
||||
password := os.Getenv(consts.NATSPassword)
|
||||
opts = append(opts, nats.UserInfo(username, password))
|
||||
} else if token := os.Getenv(consts.NATSToken); token != "" {
|
||||
opts = append(opts, nats.Token(token))
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(serviceURL, opts...)
|
||||
assert.NoError(t, err)
|
||||
defer nc.Close()
|
||||
|
||||
// Test core NATS publishing
|
||||
err = nc.Publish("test.subject", []byte("hello from core NATS"))
|
||||
assert.NoError(t, err)
|
||||
t.Log("Message sent via core NATS")
|
||||
|
||||
// Test JetStream publishing if enabled
|
||||
if os.Getenv(consts.NATSUseJetStream) == "true" {
|
||||
js, err := nc.JetStream()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure stream exists
|
||||
_, err = js.AddStream(&nats.StreamConfig{
|
||||
Name: "TEST_STREAM",
|
||||
Subjects: []string{"test.jetstream.>"},
|
||||
})
|
||||
if err != nil && err != nats.ErrStreamNameAlreadyInUse {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = js.Publish("test.jetstream.subject", []byte("hello from JetStream"))
|
||||
assert.NoError(t, err)
|
||||
t.Log("Message sent via JetStream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNATSConsumer(t *testing.T) {
|
||||
if os.Getenv("NATS_LOCAL_TEST") != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
// Set up NATS connection options
|
||||
opts := []nats.Option{nats.Name("test-consumer")}
|
||||
|
||||
// Add authentication if provided
|
||||
if jwtToken := os.Getenv(consts.NATSJWTToken); jwtToken != "" {
|
||||
opts = append(opts, nats.UserJWT(func() (string, error) {
|
||||
return jwtToken, nil
|
||||
}, func(nonce []byte) ([]byte, error) {
|
||||
return []byte(os.Getenv(consts.NATSNKeySeed)), nil
|
||||
}))
|
||||
} else if username := os.Getenv(consts.NATSUsername); username != "" {
|
||||
password := os.Getenv(consts.NATSPassword)
|
||||
opts = append(opts, nats.UserInfo(username, password))
|
||||
} else if token := os.Getenv(consts.NATSToken); token != "" {
|
||||
opts = append(opts, nats.Token(token))
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(serviceURL, opts...)
|
||||
assert.NoError(t, err)
|
||||
defer nc.Close()
|
||||
|
||||
// Test core NATS subscription
|
||||
t.Run("CoreNATSConsumer", func(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
// Subscribe to messages
|
||||
sub, err := nc.QueueSubscribe("test.subject", "test-queue", func(msg *nats.Msg) {
|
||||
defer wg.Done()
|
||||
t.Logf("Received core NATS message: %s", string(msg.Data))
|
||||
assert.Equal(t, "hello from core NATS", string(msg.Data))
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Send a test message
|
||||
err = nc.Publish("test.subject", []byte("hello from core NATS"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for message with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout waiting for core NATS message")
|
||||
}
|
||||
})
|
||||
|
||||
// Test JetStream subscription if enabled
|
||||
if os.Getenv(consts.NATSUseJetStream) == "true" {
|
||||
t.Run("JetStreamConsumer", func(t *testing.T) {
|
||||
js, err := nc.JetStream()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure stream exists
|
||||
_, err = js.AddStream(&nats.StreamConfig{
|
||||
Name: "TEST_STREAM",
|
||||
Subjects: []string{"test.jetstream.>"},
|
||||
})
|
||||
if err != nil && err != nats.ErrStreamNameAlreadyInUse {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
// Subscribe to JetStream messages
|
||||
sub, err := js.PullSubscribe("test.jetstream.subject", "test-consumer")
|
||||
assert.NoError(t, err)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Send a test message
|
||||
_, err = js.Publish("test.jetstream.subject", []byte("hello from JetStream"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
msgs, err := sub.Fetch(1, nats.MaxWait(5*time.Second))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to fetch JetStream message: %v", err)
|
||||
return
|
||||
}
|
||||
if len(msgs) > 0 {
|
||||
msg := msgs[0]
|
||||
t.Logf("Received JetStream message: %s", string(msg.Data))
|
||||
assert.Equal(t, "hello from JetStream", string(msg.Data))
|
||||
msg.Ack()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for message with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("Timeout waiting for JetStream message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNATSProducerImpl(t *testing.T) {
|
||||
if os.Getenv("NATS_LOCAL_TEST") != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
producer, err := NewProducer(serviceURL, "test.topic", "test-group")
|
||||
assert.NoError(t, err)
|
||||
// Note: eventbus.Producer interface doesn't have Close method
|
||||
// The underlying connection will be closed when the producer is garbage collected
|
||||
|
||||
// Test single message send
|
||||
err = producer.Send(context.Background(), []byte("single message test"))
|
||||
assert.NoError(t, err)
|
||||
t.Log("Single message sent successfully")
|
||||
|
||||
// Test batch message send
|
||||
messages := [][]byte{
|
||||
[]byte("batch message 1"),
|
||||
[]byte("batch message 2"),
|
||||
}
|
||||
|
||||
err = producer.BatchSend(context.Background(), messages)
|
||||
assert.NoError(t, err)
|
||||
t.Log("Batch messages sent successfully")
|
||||
}
|
||||
|
||||
func TestNATSConsumerImpl(t *testing.T) {
|
||||
if os.Getenv("NATS_LOCAL_TEST") != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a test message handler
|
||||
messageReceived := make(chan *eventbus.Message, 1)
|
||||
handler := &testHandler{
|
||||
messageReceived: messageReceived,
|
||||
t: t,
|
||||
}
|
||||
|
||||
// Register consumer
|
||||
err := RegisterConsumer(serviceURL, "test.consumer.impl", "test-group", handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Send a test message using producer
|
||||
producer, err := NewProducer(serviceURL, "test.consumer.impl", "test-group")
|
||||
assert.NoError(t, err)
|
||||
// Note: eventbus.Producer interface doesn't have Close method
|
||||
// The underlying connection will be closed when the producer is garbage collected
|
||||
|
||||
err = producer.Send(context.Background(), []byte("consumer implementation test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for message to be received
|
||||
select {
|
||||
case receivedMsg := <-messageReceived:
|
||||
assert.Equal(t, []byte("consumer implementation test"), receivedMsg.Body)
|
||||
t.Log("Consumer implementation test passed")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("Timeout waiting for message in consumer implementation test")
|
||||
}
|
||||
}
|
||||
|
||||
// testHandler implements eventbus.ConsumerHandler for testing
|
||||
type testHandler struct {
|
||||
messageReceived chan *eventbus.Message
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleMessage(ctx context.Context, message *eventbus.Message) error {
|
||||
h.t.Logf("Handler received message: %s", string(message.Body))
|
||||
h.messageReceived <- message
|
||||
return nil
|
||||
}
|
||||
243
backend/infra/eventbus/impl/nats/producer.go
Normal file
243
backend/infra/eventbus/impl/nats/producer.go
Normal file
@ -0,0 +1,243 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type producerImpl struct {
|
||||
nc *nats.Conn
|
||||
js nats.JetStreamContext
|
||||
useJetStream bool
|
||||
topic string // Store the topic for this producer instance
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProducer creates a new NATS producer
|
||||
func NewProducer(serverURL, topic, group string) (eventbus.Producer, error) {
|
||||
if serverURL == "" {
|
||||
return nil, fmt.Errorf("server URL is empty")
|
||||
}
|
||||
|
||||
if topic == "" {
|
||||
return nil, fmt.Errorf("topic is empty")
|
||||
}
|
||||
|
||||
// Set up NATS connection options
|
||||
opts := []nats.Option{
|
||||
nats.Name("coze-studio-producer"),
|
||||
nats.MaxReconnects(-1), // Unlimited reconnects
|
||||
}
|
||||
|
||||
// Add authentication if provided
|
||||
if jwtToken := os.Getenv(consts.NATSJWTToken); jwtToken != "" {
|
||||
nkeySeed := os.Getenv(consts.NATSNKeySeed)
|
||||
opts = append(opts, nats.UserJWTAndSeed(jwtToken, nkeySeed))
|
||||
} else if username := os.Getenv(consts.NATSUsername); username != "" {
|
||||
password := os.Getenv(consts.NATSPassword)
|
||||
opts = append(opts, nats.UserInfo(username, password))
|
||||
} else if token := os.Getenv(consts.NATSToken); token != "" {
|
||||
opts = append(opts, nats.Token(token))
|
||||
}
|
||||
|
||||
// Connect to NATS
|
||||
nc, err := nats.Connect(serverURL, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to NATS failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if JetStream should be used
|
||||
useJetStream := os.Getenv(consts.NATSUseJetStream) == "true"
|
||||
|
||||
producer := &producerImpl{
|
||||
nc: nc,
|
||||
useJetStream: useJetStream,
|
||||
topic: topic, // Store the topic for this producer instance
|
||||
closed: false,
|
||||
}
|
||||
|
||||
// Initialize JetStream if needed
|
||||
if useJetStream {
|
||||
js, err := nc.JetStream()
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
return nil, fmt.Errorf("create JetStream context failed: %w", err)
|
||||
}
|
||||
producer.js = js
|
||||
}
|
||||
|
||||
return producer, nil
|
||||
}
|
||||
|
||||
// Send sends a single message using the stored topic
|
||||
func (p *producerImpl) Send(ctx context.Context, body []byte, opts ...eventbus.SendOpt) error {
|
||||
return p.BatchSend(ctx, [][]byte{body}, opts...)
|
||||
}
|
||||
|
||||
// BatchSend sends multiple messages using the stored topic
|
||||
func (p *producerImpl) BatchSend(ctx context.Context, bodyArr [][]byte, opts ...eventbus.SendOpt) error {
|
||||
p.mu.RLock()
|
||||
if p.closed {
|
||||
p.mu.RUnlock()
|
||||
return fmt.Errorf("producer is closed")
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(bodyArr) == 0 {
|
||||
return fmt.Errorf("no messages to send")
|
||||
}
|
||||
|
||||
// Use the stored topic
|
||||
topic := p.topic
|
||||
if topic == "" {
|
||||
return fmt.Errorf("topic is not set")
|
||||
}
|
||||
|
||||
// Parse producer options
|
||||
option := &eventbus.SendOption{}
|
||||
for _, opt := range opts {
|
||||
opt(option)
|
||||
}
|
||||
|
||||
if p.useJetStream {
|
||||
return p.batchSendJetStream(ctx, topic, bodyArr, option)
|
||||
} else {
|
||||
return p.batchSendCore(ctx, topic, bodyArr, option)
|
||||
}
|
||||
}
|
||||
|
||||
// batchSendJetStream sends messages using JetStream for persistence
|
||||
func (p *producerImpl) batchSendJetStream(ctx context.Context, topic string, messages [][]byte, option *eventbus.SendOption) error {
|
||||
// Ensure Stream exists
|
||||
if err := ensureStream(p.js, topic); err != nil {
|
||||
return fmt.Errorf("ensure stream failed: %w", err)
|
||||
}
|
||||
|
||||
// Use TaskGroup to wait for all async publishes
|
||||
tg := taskgroup.NewTaskGroup(ctx, min(len(messages), 5))
|
||||
|
||||
for i, message := range messages {
|
||||
tg.Go(func() error {
|
||||
// Prepare publish options
|
||||
pubOpts := []nats.PubOpt{}
|
||||
|
||||
// Add message ID for deduplication if sharding key is provided
|
||||
if option.ShardingKey != nil && *option.ShardingKey != "" {
|
||||
msgID := fmt.Sprintf("%s-%d", *option.ShardingKey, i)
|
||||
pubOpts = append(pubOpts, nats.MsgId(msgID))
|
||||
}
|
||||
|
||||
// Add context for timeout
|
||||
pubOpts = append(pubOpts, nats.Context(ctx))
|
||||
|
||||
// Publish message asynchronously
|
||||
_, err := p.js.Publish(topic, message, pubOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("publish message %d failed: %w", i, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all messages to be sent
|
||||
if err := tg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.Debugf("successfully sent %d messages to NATS JetStream topic: %s", len(messages), topic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// batchSendCore sends messages using core NATS for simple pub/sub
|
||||
func (p *producerImpl) batchSendCore(ctx context.Context, topic string, messages [][]byte, option *eventbus.SendOption) error {
|
||||
// Use TaskGroup to wait for all async publishes
|
||||
tg := taskgroup.NewTaskGroup(ctx, min(len(messages), 5))
|
||||
|
||||
for i, message := range messages {
|
||||
tg.Go(func() error {
|
||||
// For core NATS, we can add headers if sharding key is provided
|
||||
if option.ShardingKey != nil && *option.ShardingKey != "" {
|
||||
// Create message with headers
|
||||
natsMsg := &nats.Msg{
|
||||
Subject: topic,
|
||||
Data: message,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
natsMsg.Header.Set("Sharding-Key", *option.ShardingKey)
|
||||
|
||||
err := p.nc.PublishMsg(natsMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("publish message %d with header failed: %w", i, err)
|
||||
}
|
||||
} else {
|
||||
// Simple publish without headers
|
||||
err := p.nc.Publish(topic, message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("publish message %d failed: %w", i, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all messages to be sent
|
||||
if err := tg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush to ensure all messages are sent
|
||||
if err := p.nc.Flush(); err != nil {
|
||||
return fmt.Errorf("flush NATS connection failed: %w", err)
|
||||
}
|
||||
logs.Debugf("successfully sent %d messages to NATS core topic: %s", len(messages), topic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the producer and releases resources
|
||||
func (p *producerImpl) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closed = true
|
||||
|
||||
if p.nc != nil {
|
||||
// Drain connection to ensure all pending messages are sent
|
||||
if err := p.nc.Drain(); err != nil {
|
||||
logs.Warnf("drain NATS connection failed: %v", err)
|
||||
}
|
||||
p.nc.Close()
|
||||
}
|
||||
|
||||
logs.Infof("NATS producer closed successfully")
|
||||
return nil
|
||||
}
|
||||
@ -60,6 +60,12 @@ const (
|
||||
RMQAccessKey = "RMQ_ACCESS_KEY"
|
||||
PulsarServiceURL = "PULSAR_SERVICE_URL"
|
||||
PulsarJWTToken = "PULSAR_JWT_TOKEN"
|
||||
NATSJWTToken = "NATS_JWT_TOKEN"
|
||||
NATSNKeySeed = "NATS_NKEY_SEED"
|
||||
NATSUsername = "NATS_USERNAME"
|
||||
NATSPassword = "NATS_PASSWORD"
|
||||
NATSToken = "NATS_TOKEN"
|
||||
NATSUseJetStream = "NATS_USE_JETSTREAM"
|
||||
RMQTopicApp = "opencoze_search_app"
|
||||
RMQTopicResource = "opencoze_search_resource"
|
||||
RMQTopicKnowledge = "opencoze_knowledge"
|
||||
|
||||
Reference in New Issue
Block a user