stream.go
go
package friendli
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
)
// ErrStreamClosed is the error Recv reports when it is called on a stream
// that has already been closed.
var ErrStreamClosed = errors.New("stream is closed")

// streamDoneMarker is the SSE data payload that marks the end of the stream;
// no chunk is decoded for it.
var streamDoneMarker = "[DONE]"
// ChatCompletionStream incrementally reads a streaming chat completion HTTP
// response that is framed as Server-Sent Events ("data: ..." lines).
// Chunks are obtained with Recv; Close releases the response body.
type ChatCompletionStream struct {
	response *http.Response // source response; Close closes its Body
	reader   *bufio.Reader  // buffered reader over response.Body for line reads
	closed   bool           // true once Close has run; Recv then fails fast
}
// newChatCompletionStream wraps resp in a ChatCompletionStream that reads the
// response body line by line. resp.Body is closed by the stream's Close
// method, which Recv also invokes once the end of the stream is reached.
func newChatCompletionStream(resp *http.Response) *ChatCompletionStream {
	s := new(ChatCompletionStream)
	s.response = resp
	s.reader = bufio.NewReader(resp.Body)
	return s
}
// Recv reads the next chunk from the stream.
//
// The response body is consumed line by line as a Server-Sent Events stream.
// Only "data:" fields carry chunks; the space after the colon is optional, as
// in the SSE spec. Blank lines, comment lines (": ...") and other fields such
// as "event:" or "id:" are skipped. Data fields with an empty payload are
// skipped too. Each remaining payload is decoded as JSON into a
// ChatCompletionChunk.
//
// Recv returns io.EOF, and closes the stream, when the "[DONE]" marker
// arrives or the body ends. A final line that is not terminated by '\n' is
// still processed before io.EOF is reported. Once the stream is closed, Recv
// returns ErrStreamClosed. Read and JSON errors are wrapped and returned
// without closing the stream, so the caller remains responsible for Close.
func (s *ChatCompletionStream) Recv() (*ChatCompletionChunk, error) {
	if s.closed {
		return nil, ErrStreamClosed
	}
	for {
		line, readErr := s.reader.ReadString('\n')
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			return nil, fmt.Errorf("failed to read stream: %w", readErr)
		}
		// At EOF, ReadString still returns the bytes read before it. A last
		// line without a trailing newline must therefore be handled before
		// reporting end of stream.
		atEOF := readErr != nil

		line = strings.TrimSpace(line)
		if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			switch data {
			case "":
				// Empty payload: nothing to decode, keep reading.
			case streamDoneMarker:
				// The Close error is deliberately ignored: end-of-stream is
				// the result the caller needs here.
				_ = s.Close()
				return nil, io.EOF
			default:
				var chunk ChatCompletionChunk
				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
					return nil, fmt.Errorf("failed to unmarshal chunk: %w", err)
				}
				return &chunk, nil
			}
		}

		if atEOF {
			_ = s.Close()
			return nil, io.EOF
		}
	}
}
// Close releases the stream by closing the underlying HTTP response body.
// Only the first call closes the body and returns its error; later calls
// are no-ops that return nil. Recv calls Close itself when the stream ends.
func (s *ChatCompletionStream) Close() error {
	if !s.closed {
		s.closed = true
		return s.response.Body.Close()
	}
	return nil
}
// StreamToChannel drains the stream and sends each chunk on ch. It is
// typically run in its own goroutine, so chunks can be processed
// concurrently, for example forwarded to HTMX server-sent events.
//
// When the stream ends normally (io.EOF) it simply returns. On any other
// error it sends that error on errCh and returns. In both cases errCh and
// then ch are closed on return, so receivers can detect completion.
//
// The error is sent before ch is closed. If errCh is unbuffered, read both
// channels concurrently (e.g. with select). Ranging over ch to completion
// before reading errCh would deadlock when an error occurs. The stream is not
// closed on error, so the caller should still call Close.
func (s *ChatCompletionStream) StreamToChannel(ch chan<- *ChatCompletionChunk, errCh chan<- error) {
	// Deferred closes run in reverse order: errCh is closed first, then ch.
	defer close(ch)
	defer close(errCh)
	for {
		next, err := s.Recv()
		switch {
		case err == nil:
			ch <- next
		case err == io.EOF:
			return
		default:
			errCh <- err
			return
		}
	}
}
// CollectContent consumes the whole stream and returns the concatenation of
// the delta content of the first choice in every chunk. Chunks without
// choices contribute nothing. On the first non-EOF error it returns "" and
// that error, discarding any content gathered so far.
func (s *ChatCompletionStream) CollectContent() (string, error) {
	var sb strings.Builder
	for {
		chunk, err := s.Recv()
		if err == io.EOF {
			return sb.String(), nil
		}
		if err != nil {
			return "", err
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		sb.WriteString(chunk.Choices[0].Delta.Content)
	}
}
// CollectAll consumes the whole stream and returns every chunk in arrival
// order. The slice is nil when the stream produced no chunks. On the first
// non-EOF error it returns nil and that error.
func (s *ChatCompletionStream) CollectAll() ([]ChatCompletionChunk, error) {
	var all []ChatCompletionChunk
	for {
		next, err := s.Recv()
		switch {
		case err == io.EOF:
			return all, nil
		case err != nil:
			return nil, err
		}
		all = append(all, *next)
	}
}
No comments yet.