NATS and SSE

Recently, Synadia released a preview of their HTTP Gateway. It allows clients that otherwise wouldn’t be able to use NATS to do so.

This is a cool idea and I figured I’d take a stab at building a small PoC that would offer similar capabilities.

Right now, Synadia’s HTTP Gateway supports KV and pub/sub. I wanted to take this a step further and leverage JetStream under the hood.

Types

First let’s define some types for our app. We are going to create a map of Jetstream policy types. Then we are going to define an AppContext we can use to pass our Jetstream context into our handler. Then we can define an AppHandlerFunc which takes in the AppContext. Finally let’s define a struct that we can use to define our consumer options.

// policies maps the short policy names accepted from clients ("new", "all",
// "last") to the corresponding JetStream deliver policy.
var policies = map[string]jetstream.DeliverPolicy{
        "new":  jetstream.DeliverNewPolicy,
        "all":  jetstream.DeliverAllPolicy,
        "last": jetstream.DeliverLastPolicy,
}

// AppContext carries the dependencies handed to the HTTP handlers. Its js
// field is the JetStream handle used to create consumers.
type AppContext struct {
        js jetstream.JetStream
}

// AppHandlerFunc is an HTTP handler that additionally receives the AppContext.
type AppHandlerFunc func(http.ResponseWriter, *http.Request, AppContext)

// ConsumerOpts holds the settings used to build a consumer: id is appended to
// "events." to form the filter subject and policy is the deliver policy.
type ConsumerOpts struct {
        id     string
        policy jetstream.DeliverPolicy
}

Consumer

Now let’s create a method that builds our ephemeral consumer for us. This will assume you have a stream named Events. It will return a consumer iterator that we can range over later on. We will also create a function that ranges over messages from the iterator and puts them on a channel.

// newConsumer creates (or updates) a consumer on the "Events" stream that is
// filtered to the subject "events.<id>" and returns a message iterator for it.
//
// opts.id must be a single subject token: it may not be empty and may not
// contain '.', '*', '>' or whitespace. Without this check a caller-supplied id
// such as ">" would turn the filter into a wildcard covering every id.
//
// The consumer is configured with an InactiveThreshold of one second.
//
// It returns an error if the id is invalid, if the consumer cannot be
// created, or if the iterator cannot be opened.
func (a AppContext) newConsumer(ctx context.Context, opts ConsumerOpts) (jetstream.MessagesContext, error) {
        if opts.id == "" {
                return nil, errors.New("consumer id must not be empty")
        }
        for _, r := range opts.id {
                switch r {
                case '.', '*', '>', ' ', '\t', '\r', '\n':
                        return nil, fmt.Errorf("consumer id %q contains invalid character %q", opts.id, r)
                }
        }

        subject := fmt.Sprintf("events.%s", opts.id)
        config := jetstream.ConsumerConfig{
                DeliverPolicy:     opts.policy,
                FilterSubject:     subject,
                InactiveThreshold: 1 * time.Second,
        }

        con, err := a.js.CreateOrUpdateConsumer(ctx, "Events", config)
        if err != nil {
                return nil, fmt.Errorf("creating consumer for %q: %w", subject, err)
        }

        it, err := con.Messages()
        if err != nil {
                return nil, fmt.Errorf("opening message iterator for %q: %w", subject, err)
        }
        return it, nil
}

// handleMessages forwards message payloads from it to ch until ctx is
// cancelled or the iterator reports that it has been closed.
//
// Each payload is sent as a string, and the message is acknowledged only after
// the send has completed. Iterator errors other than
// jetstream.ErrMsgIteratorClosed are logged and the loop carries on.
func handleMessages(ctx context.Context, it jetstream.MessagesContext, ch chan<- string) {
        for {
                // Check for cancellation before asking the iterator for more.
                select {
                case <-ctx.Done():
                        return
                default:
                }

                m, err := it.Next()
                if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
                        // A closed iterator will not yield again; retrying would spin.
                        return
                }
                if err != nil {
                        log.Println(err)
                        continue
                }

                // Give up on the send if the context ends, so this goroutine
                // cannot block forever on a receiver that has gone away.
                select {
                case ch <- string(m.Data()):
                case <-ctx.Done():
                        return
                }

                if err := m.Ack(); err != nil {
                        log.Println(err)
                }
        }
}

Ping

Here we are going to write a function that will send a ping message every 30 seconds to the client to ensure the connection is kept alive.

// ping sends the keep-alive payload {"system": "ping"} on ch every 30 seconds
// until ctx is cancelled.
//
// Both the wait and the send observe ctx, so the goroutine exits promptly on
// cancellation even when nobody is receiving from ch any more.
func ping(ctx context.Context, ch chan<- string) {
        const interval = 30 * time.Second

        ticker := time.NewTicker(interval)
        defer ticker.Stop()

        for {
                select {
                case <-ctx.Done():
                        return
                case <-ticker.C:
                }

                select {
                case ch <- `{"system": "ping"}`:
                case <-ctx.Done():
                        return
                }
        }
}

Handler

Finally let’s build our handler and a wrapper that adapts it to an http.HandlerFunc so it can be registered with the server. We are using a header here to define the user ID; this could just as well come from a JWT or some other method. We set the content type to be text/event-stream, set our connection to keep-alive, and our cache-control to no-cache.

We also allow for a query param named deliver. This can be set to new, all, or last and control our ephemeral consumer delivery policy.

We ensure that we can treat our ResponseWriter as an http.Flusher and then create our consumer and iterator. We can then create a channel of strings and set up some goroutines. One anonymous goroutine to handle draining our iterator, another to handle sending our ping messages, and a final goroutine to iterate over messages from our consumer. These are then passed over a channel back to our handler which ranges over those messages and returns them to the client.


// AppHandler adapts h to an http.HandlerFunc by binding a as the AppContext
// that is passed to h on every request.
func AppHandler(h AppHandlerFunc, a AppContext) http.HandlerFunc {
        bound := func(w http.ResponseWriter, r *http.Request) {
                h(w, r, a)
        }
        return http.HandlerFunc(bound)
}

// handleEvents streams events to the client as a long-lived
// text/event-stream response.
//
// The user is identified by the "user_id" request header. The optional
// "deliver" query parameter (new, all or last; default new) selects the
// consumer's deliver policy. Every value received from the consumer or from
// the ping goroutine is written on its own line and flushed immediately.
//
// The handler responds with 400 for a missing user id or an unknown deliver
// value and with 500 if streaming is unsupported or the consumer cannot be
// created. Once streaming has begun it returns when the request context is
// cancelled or a write fails.
//
// NOTE(review): values are written as bare lines, not as "data:"-prefixed
// events; confirm that the intended clients accept this framing.
func handleEvents(w http.ResponseWriter, r *http.Request, app AppContext) {
        ctx := r.Context()

        id := r.Header.Get("user_id")
        if id == "" {
                http.Error(w, "user_id header is required", http.StatusBadRequest)
                return
        }

        policy := r.URL.Query().Get("deliver")
        if policy == "" {
                policy = "new"
        }

        p, ok := policies[policy]
        if !ok {
                http.Error(w, "deliver must be new, all, or last", http.StatusBadRequest)
                return
        }

        flusher, ok := w.(http.Flusher)
        if !ok {
                http.Error(w, "internal server error", http.StatusInternalServerError)
                return
        }

        it, err := app.newConsumer(ctx, ConsumerOpts{
                id:     id,
                policy: p,
        })
        if err != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                // it is nil on error; carrying on would panic when it is drained.
                return
        }

        // Set the streaming headers only now, so error responses above are not
        // sent with event-stream headers.
        w.Header().Set("Content-Type", "text/event-stream")
        w.Header().Set("Cache-Control", "no-cache")
        w.Header().Set("Connection", "keep-alive")

        ch := make(chan string)

        go func() {
                <-ctx.Done()
                it.Drain()
        }()

        go ping(ctx, ch)
        go handleMessages(ctx, it, ch)

        // ch is never closed, so the loop must watch ctx to terminate.
        for {
                select {
                case <-ctx.Done():
                        return
                case v := <-ch:
                        // v is payload data, never a format string.
                        if _, err := fmt.Fprintf(w, "%s\n", v); err != nil {
                                // The response has already started; http.Error
                                // cannot change the status any more.
                                log.Println(err)
                                return
                        }
                        flusher.Flush()
                }
        }
}

Main

Here’s our main function to start everything. Here we are setting up our connection with our NATS context, getting our Jetstream context, and then building our AppContext and our server.

// main connects to NATS, obtains a JetStream handle, and serves the events
// endpoint on :8080. Any startup or server error terminates the process.
func main() {
        opts := []nats.Option{nats.Name("ssetest")}
        nc, err := natscontext.Connect("", opts...)
        if err != nil {
                log.Fatal(err)
        }

        js, err := jetstream.New(nc)
        if err != nil {
                log.Fatal(err)
        }

        app := AppContext{
                js: js,
        }

        r := http.NewServeMux()

        // The path must match the one clients request (/api/v1/events).
        r.Handle("GET /api/v1/events", AppHandler(handleEvents, app))

        log.Fatal(http.ListenAndServe(":8080", r))
}

Getting events

To get events from our server, we can use cURL with -N.

curl -N "http://localhost:8080/api/v1/events?deliver=new" -H "user_id: jhooks"

NATS SSE Demo

Full Example

package main

import (
        "context"
        "errors"
        "fmt"
        "log"
        "net/http"
        "time"

        "github.com/nats-io/jsm.go/natscontext"
        "github.com/nats-io/nats.go"
        "github.com/nats-io/nats.go/jetstream"
)

// policies maps the short policy names accepted from clients ("new", "all",
// "last") to the corresponding JetStream deliver policy.
var policies = map[string]jetstream.DeliverPolicy{
        "new":  jetstream.DeliverNewPolicy,
        "all":  jetstream.DeliverAllPolicy,
        "last": jetstream.DeliverLastPolicy,
}

// AppContext carries the dependencies handed to the HTTP handlers. Its js
// field is the JetStream handle used to create consumers.
type AppContext struct {
        js jetstream.JetStream
}

// AppHandlerFunc is an HTTP handler that additionally receives the AppContext.
type AppHandlerFunc func(http.ResponseWriter, *http.Request, AppContext)

// ConsumerOpts holds the settings used to build a consumer: id names the
// consumer and is appended to "events." to form the filter subject, and
// policy is the deliver policy.
type ConsumerOpts struct {
        id     string
        policy jetstream.DeliverPolicy
}

// AppHandler adapts h to an http.HandlerFunc by binding a as the AppContext
// that is passed to h on every request.
func AppHandler(h AppHandlerFunc, a AppContext) http.HandlerFunc {
        bound := func(w http.ResponseWriter, r *http.Request) {
                h(w, r, a)
        }
        return http.HandlerFunc(bound)
}

// newConsumer creates (or updates) a consumer named after opts.id on the
// "Events" stream, filtered to the subject "events.<id>", and returns a
// message iterator for it.
//
// opts.id must be a single subject token: it may not be empty and may not
// contain '.', '*', '>' or whitespace. Without this check a caller-supplied id
// such as ">" would turn the filter into a wildcard covering every id.
//
// The consumer is configured with an InactiveThreshold of one second.
//
// It returns an error if the id is invalid, if the consumer cannot be
// created, or if the iterator cannot be opened.
func (a AppContext) newConsumer(ctx context.Context, opts ConsumerOpts) (jetstream.MessagesContext, error) {
        if opts.id == "" {
                return nil, errors.New("consumer id must not be empty")
        }
        for _, r := range opts.id {
                switch r {
                case '.', '*', '>', ' ', '\t', '\r', '\n':
                        return nil, fmt.Errorf("consumer id %q contains invalid character %q", opts.id, r)
                }
        }

        subject := fmt.Sprintf("events.%s", opts.id)
        config := jetstream.ConsumerConfig{
                Name:              opts.id,
                DeliverPolicy:     opts.policy,
                FilterSubject:     subject,
                InactiveThreshold: 1 * time.Second,
        }

        con, err := a.js.CreateOrUpdateConsumer(ctx, "Events", config)
        if err != nil {
                return nil, fmt.Errorf("creating consumer for %q: %w", subject, err)
        }

        it, err := con.Messages()
        if err != nil {
                return nil, fmt.Errorf("opening message iterator for %q: %w", subject, err)
        }
        return it, nil
}

// handleEvents streams events to the client as a long-lived
// text/event-stream response.
//
// The user is identified by the "user_id" request header. The optional
// "deliver" query parameter (new, all or last; default new) selects the
// consumer's deliver policy. Every value received from the consumer or from
// the ping goroutine is written on its own line and flushed immediately.
//
// The handler responds with 400 for a missing user id or an unknown deliver
// value and with 500 if streaming is unsupported or the consumer cannot be
// created. Once streaming has begun it returns when the request context is
// cancelled or a write fails.
//
// NOTE(review): values are written as bare lines, not as "data:"-prefixed
// events; confirm that the intended clients accept this framing.
func handleEvents(w http.ResponseWriter, r *http.Request, app AppContext) {
        ctx := r.Context()

        // Clients send the id in the "user_id" header (see the curl example).
        id := r.Header.Get("user_id")
        if id == "" {
                http.Error(w, "user_id header is required", http.StatusBadRequest)
                return
        }

        policy := r.URL.Query().Get("deliver")
        if policy == "" {
                policy = "new"
        }

        p, ok := policies[policy]
        if !ok {
                http.Error(w, "deliver must be new, all, or last", http.StatusBadRequest)
                return
        }

        flusher, ok := w.(http.Flusher)
        if !ok {
                http.Error(w, "internal server error", http.StatusInternalServerError)
                return
        }

        it, err := app.newConsumer(ctx, ConsumerOpts{
                id:     id,
                policy: p,
        })
        if err != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                // it is nil on error; carrying on would panic when it is drained.
                return
        }

        // Set the streaming headers only now, so error responses above are not
        // sent with event-stream headers.
        w.Header().Set("Content-Type", "text/event-stream")
        w.Header().Set("Cache-Control", "no-cache")
        w.Header().Set("Connection", "keep-alive")

        ch := make(chan string)

        go func() {
                <-ctx.Done()
                it.Drain()
        }()

        go ping(ctx, ch)
        go handleMessages(ctx, it, ch)

        // ch is never closed, so the loop must watch ctx to terminate.
        for {
                select {
                case <-ctx.Done():
                        return
                case v := <-ch:
                        // v is payload data, never a format string.
                        if _, err := fmt.Fprintf(w, "%s\n", v); err != nil {
                                // The response has already started; http.Error
                                // cannot change the status any more.
                                log.Println(err)
                                return
                        }
                        flusher.Flush()
                }
        }
}

// ping sends the keep-alive payload {"system": "ping"} on ch every 30 seconds
// until ctx is cancelled.
//
// Both the wait and the send observe ctx, so the goroutine exits promptly on
// cancellation even when nobody is receiving from ch any more.
func ping(ctx context.Context, ch chan<- string) {
        const interval = 30 * time.Second

        ticker := time.NewTicker(interval)
        defer ticker.Stop()

        for {
                select {
                case <-ctx.Done():
                        return
                case <-ticker.C:
                }

                select {
                case ch <- `{"system": "ping"}`:
                case <-ctx.Done():
                        return
                }
        }
}

// handleMessages forwards message payloads from it to ch until ctx is
// cancelled or the iterator reports that it has been closed.
//
// Each payload is sent as a string, and the message is acknowledged only after
// the send has completed. Iterator errors other than
// jetstream.ErrMsgIteratorClosed are logged and the loop carries on.
func handleMessages(ctx context.Context, it jetstream.MessagesContext, ch chan<- string) {
        for {
                // Check for cancellation before asking the iterator for more.
                select {
                case <-ctx.Done():
                        return
                default:
                }

                m, err := it.Next()
                if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
                        // A closed iterator will not yield again; retrying would spin.
                        return
                }
                if err != nil {
                        log.Println(err)
                        continue
                }

                // Give up on the send if the context ends, so this goroutine
                // cannot block forever on a receiver that has gone away.
                select {
                case ch <- string(m.Data()):
                case <-ctx.Done():
                        return
                }

                if err := m.Ack(); err != nil {
                        log.Println(err)
                }
        }
}

func main() {
        var err error
        opts := []nats.Option{nats.Name("ssetest")}
        nc, err := natscontext.Connect("", opts...)
        if err != nil {
                log.Fatal(err)
        }

        js, err := jetstream.New(nc)
        if err != nil {
                log.Fatal(err)
        }

        app := AppContext{
                js: js,
        }

        r := http.NewServeMux()

        r.Handle("GET /api/v1/events", AppHandler(handleEvents, app))

        log.Fatal(http.ListenAndServe(":8080", r))
}