package main

import (
	"fmt"
	"net"
	"strings"
	"time"
)

// Tunables for the UDP listener and its message queue.
const (
	// RINGBUFSIZE is the number of slots in the message ring buffer. A power
	// of two lets position % RINGBUFSIZE compile down to a cheap mask.
	RINGBUFSIZE = 65536 // 1 << 16

	// MAX_DGRAM_SIZE is the largest UDP payload, in bytes, that the reader
	// accepts.
	MAX_DGRAM_SIZE = 1 << 10
)

// StatsdListener receives statsd datagrams on a UDP port, queues the
// individual metric lines in a fixed-size ring buffer, and feeds the parsed
// messages into a StatsCollector. Build one with NewListener, then Start it.
//
// The ring buffer is written by keepReading (via EnqueueMsgs) and read by
// keepProcessing (via DequeueMsg). ready is true from NewListener until Stop
// clears it; while it is false, EnqueueMsgs rejects batches and keepReading
// exits.
//
// NOTE(review): Start runs keepReading and keepProcessing as separate
// goroutines, and they share pktqueue/read_pos/write_pos/ready (and the wait
// counters) with no mutex or atomics. Under the Go memory model that is a
// data race; `go test -race` should flag it.
type StatsdListener struct {
	port int // to match UDPAddr.Port
	conn *net.UDPConn
	// processing queue, a.k.a ringbuffer
	// Position p lives in pktqueue[p%RINGBUFSIZE]. read_pos and write_pos
	// only ever increase; write_pos-read_pos is the number of queued,
	// unread messages.
	pktqueue  []string
	read_pos  uint64
	write_pos uint64
	ready     bool
	// for performance inspection
	// Counts of back-offs taken when the queue was empty (read side) or full
	// (write side), plus the sum of the requested sleep durations in
	// microseconds (nominal, not measured wall time).
	total_wait_us         uint64
	times_waited_on_read  uint64
	times_waited_on_write uint64

	// Gates maybe_log output; also passed to collector.Flip.
	verbose bool

	// Destination for processed metrics; processMessage holds
	// collector.mutex while updating it.
	collector *StatsCollector
}

// NewListener builds a listener for the given UDP port. The listener starts
// with an empty ring buffer of RINGBUFSIZE slots, is ready to accept
// messages, and is quiet (not verbose). It does not open the socket; Start
// does that.
func NewListener(port uint16) *StatsdListener {
	l := &StatsdListener{
		port:      int(port),
		pktqueue:  make([]string, RINGBUFSIZE),
		read_pos:  0,
		write_pos: 0,
		ready:     true,
		verbose:   false,
		collector: NewStatsCollector(),
	}
	return l
}

// Start binds the UDP socket on the IPv4 wildcard address at s.port and
// launches the reader, processor and periodic flip goroutines.
//
// It always returns (true, nil). A bind failure panics rather than being
// returned, and main relies on that because it ignores Start's results.
// A failure to enlarge the socket's receive buffer is only logged: the
// listener still works with the OS default size.
func (s *StatsdListener) Start() (bool, error) {
	var err error

	laddr := &net.UDPAddr{
		IP:   net.IPv4(0, 0, 0, 0),
		Port: s.port,
	}

	s.conn, err = net.ListenUDP("udp", laddr)
	if err != nil {
		panic(fmt.Sprintf("Failed to allocate socket: %v", err))
	}

	// Best-effort: a larger kernel receive buffer makes it less likely that
	// bursts are dropped before keepReading gets to them.
	const readBufBytes = 1 << 16
	if err := s.conn.SetReadBuffer(readBufBytes); err != nil {
		LOG.Printf("Could not set UDP read buffer to %d bytes: %v\n", readBufBytes, err)
	}

	go s.keepReading()
	go s.keepProcessing()
	go s.keepFlipping()

	return true, nil
}

// Stop shuts the listener down, in this order:
//   - it stops accepting new batches;
//   - it closes the UDP socket, if Start opened one;
//   - it drains whatever is still queued in the ring buffer;
//   - it logs (verbose mode only) the last message it drained.
//
// NOTE(review): drained messages are discarded, not processed. Also,
// keepProcessing may still be dequeuing concurrently while Stop drains.
func (s *StatsdListener) Stop() {
	// Clear ready before closing. EnqueueMsgs then rejects late batches, and
	// keepReading, woken by the close, can exit instead of retrying reads on
	// a closed socket.
	s.ready = false
	if s.conn != nil {
		_ = s.conn.Close() // best-effort on shutdown
	}

	var last_msg string
	for {
		m, err := s.DequeueMsg()
		if err != nil {
			// Queue is empty.
			break
		}
		// Track the most recent successfully dequeued message.
		last_msg = m
	}
	s.maybe_log(fmt.Sprintf("Final transmission: %s\n", last_msg))
}

// EnqueueMsgs appends a batch of messages to the ring buffer. Pushing is
// always a batch operation, even if the batch size is 1.
//
// The batch is all-or-nothing. If the listener is not ready, or the free
// space left in the ring buffer is smaller than the batch, nothing is
// written and an error is returned. Checking only for a completely full
// buffer is not enough: a multi-line batch could then run past read_pos and
// overwrite slots the consumer has not read yet.
func (s *StatsdListener) EnqueueMsgs(msgs []string) error {
	if !s.ready {
		return fmt.Errorf("Not accepting messages")
	}
	// write_pos - read_pos is the number of queued, unread messages.
	used := s.write_pos - s.read_pos
	if used >= RINGBUFSIZE || uint64(len(msgs)) > RINGBUFSIZE-used {
		return fmt.Errorf("Ringbuffer backed up, consume messages first (wpos=%d, rpos=%d)", s.write_pos, s.read_pos)
	}
	for _, msg := range msgs {
		s.pktqueue[s.write_pos%RINGBUFSIZE] = msg
		s.write_pos++
	}
	return nil
}

// DequeueMsg pops the oldest unread message from the ring buffer. When the
// buffer is empty it returns "" together with a "Nothing to read" error.
func (s *StatsdListener) DequeueMsg() (string, error) {
	empty := s.write_pos <= s.read_pos
	if empty {
		return "", fmt.Errorf("Nothing to read")
	}
	// Copy the slot out before advancing read_pos: once read_pos moves past
	// it, EnqueueMsgs treats the slot as free and may overwrite it.
	slot := s.read_pos % RINGBUFSIZE
	oldest := s.pktqueue[slot]
	s.read_pos++
	return oldest, nil
}

// keepProcessing is the ring buffer's consumer. It dequeues messages, parses
// them, and hands each successfully parsed message to processMessage.
// Messages that ParseStatsdMessage rejects (nil) are dropped.
//
// When the queue is empty it sleeps 600µs before polling again, and records
// the wait in times_waited_on_read/total_wait_us. Once Stop has cleared
// s.ready and the queue is empty, it returns instead of polling forever, so
// the goroutine does not outlive the listener.
//
// This method must only be invoked as a goroutine.
func (s *StatsdListener) keepProcessing() {
	const idleWaitUs = 600
	for {
		m, err := s.DequeueMsg()
		if err != nil {
			// We've drained the buffer.
			if !s.ready {
				return
			}
			s.times_waited_on_read++
			s.total_wait_us += idleWaitUs
			time.Sleep(idleWaitUs * time.Microsecond)
			continue
		}
		statsd_msg := ParseStatsdMessage(m)
		if statsd_msg == nil {
			continue
		}
		s.processMessage(statsd_msg)
	}
}

// XXX: Should these workhorses live in their own file, to make them
// really easy to isolate and work with?
// processMessage converts one parsed statsd message into the matching
// collector record and merges it into s.collector while holding
// s.collector.mutex. The record is a counter, a gauge, or a histogram
// sample. Messages of any other type are only logged (in verbose mode) and
// then dropped.
//
// NOTE(review): msg.Value (a float from the wire) is converted straight to
// uint64. Per the Go spec, converting a float whose truncated value does not
// fit in uint64 (<= -1, NaN, ±Inf, >= 2^64) gives an implementation-dependent
// result. Confirm whether negative statsd values (e.g. counter decrements)
// can reach this point.
func (s *StatsdListener) processMessage(msg *StatsdMessage) {
	s.maybe_log(fmt.Sprintf("%v\n", msg))

	kind := msg.Type
	switch kind {
	case STATSD_TYPE_COUNT:
		counter := new(Counter)
		counter.KeyName = msg.Name
		counter.Value = uint64(msg.Value)
		counter.Tags = msg.Tags

		s.collector.mutex.Lock()
		s.processCounter(counter)
		s.collector.mutex.Unlock()

	case STATSD_TYPE_GAUGE:
		gauge := new(Gauge)
		gauge.KeyName = msg.Name
		gauge.Value = uint64(msg.Value)
		gauge.Tags = msg.Tags

		s.collector.mutex.Lock()
		s.processGauge(gauge)
		s.collector.mutex.Unlock()

	case STATSD_TYPE_HISTOGRAM:
		sample := new(HistogramData)
		sample.KeyName = msg.Name
		sample.Value = uint64(msg.Value)
		sample.Tags = msg.Tags

		s.collector.mutex.Lock()
		s.processHistogram(sample)
		s.collector.mutex.Unlock()

	default:
		s.maybe_log(fmt.Sprintf("Unknown message type: %v\n", kind))
	}
}

// processCounter merges counter c into the collector's counter table. The
// first sample for a key is stored as-is, tags included. Later samples only
// add their Value to the stored counter. The only caller in this file,
// processMessage, holds s.collector.mutex around this call.
func (s *StatsdListener) processCounter(c *Counter) {
	table := s.collector.data.Counters
	if existing, seen := table[c.KeyName]; seen {
		existing.Value += c.Value
		return
	}
	table[c.KeyName] = c
}

// processGauge stores gauge g as the current value for its key. It replaces
// any earlier gauge with the same name, tags included. The only caller in
// this file, processMessage, holds s.collector.mutex around this call.
func (s *StatsdListener) processGauge(g *Gauge) {
	s.collector.data.Gauges[g.KeyName] = g
}

// processHistogram adds one histogram sample to the BucketedCounter for its
// key. The counter is created on first sight of the key, taking that
// sample's tags; tags on later samples for the same key are ignored. The
// only caller in this file, processMessage, holds s.collector.mutex around
// this call.
func (s *StatsdListener) processHistogram(h *HistogramData) {
	table := s.collector.data.Histograms
	if existing, seen := table[h.KeyName]; seen {
		existing.Include(h.Value)
		return
	}
	created := NewBucketedCounter()
	created.Tags = h.Tags
	created.Include(h.Value)
	table[h.KeyName] = created
}

// keepReading is the ring buffer's producer. It reads statsd datagrams from
// s.conn, splits each one into metric lines, and enqueues the lines as one
// batch. It checks s.ready at the top of every iteration and returns once it
// is false. The blocking read itself only returns early when the socket is
// closed; Stop does both.
//
// Datagrams larger than MAX_DGRAM_SIZE are logged and dropped. If a batch
// does not fit in the ring buffer, the whole batch is dropped, reported via
// WarnAndSuppress, and the reader backs off for 1ms.
//
// This method must only be invoked as a goroutine.
func (s *StatsdListener) keepReading() {
	// We should be able to do with 512 since this is UDP, but let's keep the
	// numbers round.
	//
	// The buffer has one spare byte beyond MAX_DGRAM_SIZE. Depending on the
	// OS, a datagram that does not fit the buffer is truncated without any
	// error. With a buffer of exactly MAX_DGRAM_SIZE, an oversized packet
	// would look like a maximum-sized one and be parsed as cut-off metrics.
	// A read that fills the spare byte means the datagram was too big.
	buf := make([]byte, MAX_DGRAM_SIZE+1)
	for {
		if !s.ready {
			LOG.Printf("Listener terminated, shutting down...\n")
			break
		}
		rlen, src, err := s.conn.ReadFromUDP(buf)
		if err != nil {
			s.maybe_log(fmt.Sprintf("Received garbage from %v: %v\n", src, err))
			continue
		}
		// Zero-length datagrams carry nothing to parse.
		if rlen == 0 {
			continue
		}
		if rlen > MAX_DGRAM_SIZE {
			LOG.Printf("Oversized datagram (>%dB) from %v\n", MAX_DGRAM_SIZE, src)
			continue
		}
		msg := strings.TrimSpace(string(buf[:rlen]))
		messages := s.splitMultiLineMsg(msg)

		// NOTE: we push the whole batch; if it does not fit, the whole batch
		// is dropped (not retried).
		if err := s.EnqueueMsgs(messages); err != nil {
			WarnAndSuppress(fmt.Sprintf("Msg push failed: %v", err))
			s.times_waited_on_write++
			s.total_wait_us += 1000
			time.Sleep(1 * time.Millisecond)
		}
	}
}

// store_own_stats feeds the listener's own throughput into the collector as
// an internal COUNT metric named "packets_enqueued", tagged type=internal.
// Its value is s.write_pos: the total number of metric lines enqueued since
// start (lines, not datagrams).
//
// NOTE(review): write_pos is cumulative, but it is submitted as a COUNT, and
// processCounter adds counts to whatever is already stored under the key.
// If collector.Flip keeps counters across intervals, this over-counts. If
// Flip resets them, each interval reports the running total rather than
// that interval's count. Confirm against Flip; a per-interval delta or a
// gauge may have been intended.
func (s *StatsdListener) store_own_stats() {
	tags := map[string]string{"type": "internal"}

	own := new(StatsdMessage)
	own.Name = "packets_enqueued"
	own.Type = STATSD_TYPE_COUNT
	own.Value = float64(s.write_pos)
	own.Tags = tags

	s.processMessage(own)
}

// keepFlipping runs forever and never returns. Every 15 seconds it records
// the listener's own stats (store_own_stats) and then calls
// s.collector.Flip(s.verbose). Start runs it as a goroutine.
func (s *StatsdListener) keepFlipping() {
	const flipPeriod = 15 * time.Second
	for {
		time.Sleep(flipPeriod)
		s.store_own_stats()
		s.collector.Flip(s.verbose)
	}
}

// splitMultiLineMsg splits one datagram payload into its individual metric
// lines. Each line is trimmed of surrounding whitespace, which also strips
// the '\r' left behind by CRLF line endings. Lines shorter than 3 bytes
// after trimming are treated as noise and skipped. The result is never nil.
func (s *StatsdListener) splitMultiLineMsg(msg string) []string {
	const minLineLen = 3
	lines := strings.Split(msg, "\n")
	out := make([]string, 0, len(lines))
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if len(line) < minLineLen {
			continue
		}
		out = append(out, line)
	}
	return out
}

// Stats logs two lines of queue diagnostics. The first has the ring buffer
// positions and the ready flag. The second has the back-off counts (read:
// the processor found the queue empty; write: the reader found it full) and
// the total requested back-off time in milliseconds.
func (s *StatsdListener) Stats() {
	rpos, wpos, ready := s.read_pos, s.write_pos, s.ready
	LOG.Printf("reader=%d, writer=%d, ready=%v\n", rpos, wpos, ready)

	readWaits, writeWaits := s.times_waited_on_read, s.times_waited_on_write
	waitedMs := float64(s.total_wait_us) / 1000.0
	LOG.Printf("times_waited: read=%d, write=%d; total_wait(ms)=%.01f\n",
		readWaits, writeWaits, waitedMs)
}

// maybe_log forwards msg to LOG unchanged, but only when the listener is in
// verbose mode; otherwise it does nothing.
func (s *StatsdListener) maybe_log(msg string) {
	if !s.verbose {
		return
	}
	LOG.Print(msg)
}

// main does the following:
//   - parses the command line;
//   - starts the UDP statsd listener;
//   - starts the collector's Prometheus HTTP endpoint;
//   - keeps the process alive for good.
//
// In verbose mode it also logs listener queue statistics every 2.5s.
func main() {
	args := parse_args()
	l := NewListener(uint16(args.UdpListenPort))
	l.verbose = args.Verbose // do we want to be noisy?
	if _, err := l.Start(); err != nil {
		panic(fmt.Sprintf("Failed to start UDP listener: %v", err))
	}
	// We start the HTTP listener separately
	l.collector.Start(uint16(args.PrometheusPort))

	// Wake once per reporting interval instead of polling every 5ms. The
	// loop never ends, which keeps main (and thus the process) alive.
	const statsInterval = 2500 * time.Millisecond
	ticker := time.NewTicker(statsInterval)
	for range ticker.C {
		if args.Verbose {
			l.Stats()
		}
	}
}
