video-iac/scripts/gitea-scaler/main.go

375 lines
8.5 KiB
Go

package main
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/redis/go-redis/v9"
"google.golang.org/grpc"
pb "gitea-scaler/externalscaler"
)
// Names shared between the webhook handler, the Redis bookkeeping helpers,
// and the gRPC scaler methods.
const (
// metricName is the metric name reported by GetMetricSpec and GetMetrics.
metricName = "gitea_waiting_jobs"
// jobKeyPrefix is prepended to a workflow job ID to form the Redis key of
// the set holding that job's labels.
jobKeyPrefix = "gitea-scaler:job:"
// labelCountPrefix is prepended to a label to form the Redis key of that
// label's waiting-job counter.
labelCountPrefix = "gitea-scaler:label:"
)
// workflowJobPayload is the subset of a workflow_job webhook body that the
// scaler decodes. Fields not listed here are ignored during JSON decoding.
type workflowJobPayload struct {
// Action is the top-level "action" field; normalizeAction consults it
// before falling back to WorkflowJob.Status.
Action string `json:"action"`
// WorkflowJob mirrors the nested "workflow_job" object.
WorkflowJob struct {
// ID identifies the job; the webhook handler rejects payloads where it is 0.
ID int64 `json:"id"`
// Status is the job status string used when Action is not recognized.
Status string `json:"status"`
// Labels are the labels attached to the job; each one is counted
// separately in Redis while the job is queued.
Labels []string `json:"labels"`
} `json:"workflow_job"`
}
// scalerServer implements the external scaler gRPC methods (IsActive,
// StreamIsActive, GetMetricSpec, GetMetrics) and the workflow_job webhook
// handler, all backed by a single Redis client.
type scalerServer struct {
// Embedded so that any service method not defined on scalerServer is
// supplied by the generated "unimplemented" stub.
pb.UnimplementedExternalScalerServer
// redis stores the per-job label sets and per-label counters.
redis *redis.Client
}
// main connects to Redis, verifies the connection with a PING, and then starts
// the gRPC scaler and the HTTP webhook server on their own goroutines. It never
// returns; the process only ends when one of the servers calls log.Fatalf.
func main() {
	opts := &redis.Options{
		Addr: getEnv("REDIS_ADDR", "redis.jam-cloud-infra.svc.cluster.local:6379"),
		DB:   mustAtoi(getEnv("REDIS_DB", "1")),
	}
	client := redis.NewClient(opts)

	// Fail fast at startup rather than on the first webhook or scaler call.
	pingErr := client.Ping(context.Background()).Err()
	if pingErr != nil {
		log.Fatalf("failed to connect to redis: %v", pingErr)
	}

	srv := &scalerServer{redis: client}
	go serveGRPC(srv)
	go serveHTTP(srv)

	// Block forever; both servers run on background goroutines.
	select {}
}
// serveGRPC binds the address from GRPC_ADDR (default ":50051"), registers
// server as the external scaler implementation, and serves until failure.
// Any listen or serve error terminates the process via log.Fatalf.
func serveGRPC(server *scalerServer) {
	addr := getEnv("GRPC_ADDR", ":50051")

	g := grpc.NewServer()
	pb.RegisterExternalScalerServer(g, server)

	listener, listenErr := net.Listen("tcp", addr)
	if listenErr != nil {
		log.Fatalf("failed to listen for grpc: %v", listenErr)
	}
	log.Printf("gRPC scaler listening on %s", listener.Addr())

	serveErr := g.Serve(listener)
	if serveErr != nil {
		log.Fatalf("failed to serve grpc: %v", serveErr)
	}
}
// serveHTTP runs the HTTP server that exposes the health probe (/healthz) and
// the workflow_job webhook endpoint (/webhooks/workflow-job). The listen
// address comes from HTTP_ADDR (default ":8080"). A serve error terminates the
// process via log.Fatalf.
//
// The server is configured with explicit timeouts: the package-level
// http.ListenAndServe helper uses none, which lets a slow or stalled client
// hold a connection (and its goroutine) open indefinitely.
func serveHTTP(server *scalerServer) {
	mux := http.NewServeMux()
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// Best-effort body; a failed write to a health probe is not actionable.
		_, _ = w.Write([]byte("ok"))
	})
	mux.HandleFunc("/webhooks/workflow-job", server.handleWorkflowJob)

	httpServer := &http.Server{
		Addr:              getEnv("HTTP_ADDR", ":8080"),
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	log.Printf("webhook server listening on %s", httpServer.Addr)
	if err := httpServer.ListenAndServe(); err != nil {
		log.Fatalf("failed to serve http: %v", err)
	}
}
// handleWorkflowJob is the HTTP handler for workflow_job webhooks.
//
// It accepts only POST requests whose X-Gitea-Event (or, failing that,
// X-GitHub-Event) header is "workflow_job", verifies the signature via
// verifySignature, decodes the payload, and then:
//   - records the job as queued when normalizeAction returns "queued";
//   - removes the job when normalizeAction returns "remove";
//   - logs and ignores anything else.
//
// Responses: 405 for non-POST, 400 for a wrong event / unreadable body /
// invalid JSON / missing job ID, 413 for an oversized body, 401 for a failed
// signature check, 500 when the Redis update fails, and 202 otherwise.
func (s *scalerServer) handleWorkflowJob(w http.ResponseWriter, r *http.Request) {
	// Upper bound on the webhook body. The body is read before the signature
	// can be checked, so it must be capped to keep unauthenticated callers
	// from forcing arbitrarily large allocations.
	const maxBodyBytes = 1 << 20 // 1 MiB

	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	event := r.Header.Get("X-Gitea-Event")
	if event == "" {
		event = r.Header.Get("X-GitHub-Event")
	}
	if event != "workflow_job" {
		http.Error(w, "unsupported event", http.StatusBadRequest)
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		var tooLarge *http.MaxBytesError
		if errors.As(err, &tooLarge) {
			http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
			return
		}
		http.Error(w, "failed to read request body", http.StatusBadRequest)
		return
	}

	// The signature covers the raw bytes, so verify before decoding.
	if err := verifySignature(body, r); err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	var payload workflowJobPayload
	if err := json.Unmarshal(body, &payload); err != nil {
		http.Error(w, "invalid json payload", http.StatusBadRequest)
		return
	}
	if payload.WorkflowJob.ID == 0 {
		http.Error(w, "missing workflow_job.id", http.StatusBadRequest)
		return
	}

	switch normalizeAction(payload) {
	case "queued":
		if err := s.markQueued(r.Context(), payload.WorkflowJob.ID, payload.WorkflowJob.Labels); err != nil {
			// The client only sees a generic message; keep the cause in the log.
			log.Printf("recording queued job id=%d: %v", payload.WorkflowJob.ID, err)
			http.Error(w, "failed to record queued job", http.StatusInternalServerError)
			return
		}
	case "remove":
		if err := s.markComplete(r.Context(), payload.WorkflowJob.ID); err != nil {
			log.Printf("removing job id=%d: %v", payload.WorkflowJob.ID, err)
			http.Error(w, "failed to remove completed job", http.StatusInternalServerError)
			return
		}
	default:
		log.Printf("ignoring workflow_job action=%q status=%q id=%d", payload.Action, payload.WorkflowJob.Status, payload.WorkflowJob.ID)
	}

	w.WriteHeader(http.StatusAccepted)
}
// IsActive reports whether any job is currently waiting for the labels listed
// in the scaled object's metadata. The result is true when the summed label
// counters are greater than zero. Errors from the counter lookup are returned
// unchanged.
func (s *scalerServer) IsActive(ctx context.Context, ref *pb.ScaledObjectRef) (*pb.IsActiveResponse, error) {
	labels := labelsFromRef(ref)
	waiting, err := s.countForLabels(ctx, labels)
	if err != nil {
		return nil, err
	}
	active := waiting > 0
	return &pb.IsActiveResponse{Result: active}, nil
}
// StreamIsActive pushes the current active state to the stream immediately and
// then again every 15 seconds. It returns nil once the stream's context is
// done, and returns the underlying error if the counter lookup or the send
// fails.
func (s *scalerServer) StreamIsActive(ref *pb.ScaledObjectRef, stream pb.ExternalScaler_StreamIsActiveServer) error {
	poll := time.NewTicker(15 * time.Second)
	defer poll.Stop()

	for {
		waiting, err := s.countForLabels(stream.Context(), labelsFromRef(ref))
		if err != nil {
			return err
		}

		update := &pb.IsActiveResponse{Result: waiting > 0}
		if sendErr := stream.Send(update); sendErr != nil {
			return sendErr
		}

		// Sleep until the next tick, or stop when the client goes away.
		select {
		case <-poll.C:
			continue
		case <-stream.Context().Done():
			return nil
		}
	}
}
// GetMetricSpec returns the single metric this scaler exposes: metricName
// with a target size of 1. Both arguments are ignored.
func (s *scalerServer) GetMetricSpec(context.Context, *pb.ScaledObjectRef) (*pb.GetMetricSpecResponse, error) {
	spec := &pb.MetricSpec{
		MetricName: metricName,
		TargetSize: 1,
	}
	resp := &pb.GetMetricSpecResponse{
		MetricSpecs: []*pb.MetricSpec{spec},
	}
	return resp, nil
}
// GetMetrics returns the current value of metricName: the summed waiting-job
// counters for the labels named in the request's scaled object metadata.
// Errors from the counter lookup are returned unchanged.
func (s *scalerServer) GetMetrics(ctx context.Context, req *pb.GetMetricsRequest) (*pb.GetMetricsResponse, error) {
	labels := labelsFromRef(req.GetScaledObjectRef())
	waiting, err := s.countForLabels(ctx, labels)
	if err != nil {
		return nil, err
	}

	value := &pb.MetricValue{
		MetricName:  metricName,
		MetricValue: int64(waiting),
	}
	return &pb.GetMetricsResponse{MetricValues: []*pb.MetricValue{value}}, nil
}
// markQueued records jobID as waiting for each of the given labels.
//
// Any labels previously stored for the job are first undone (their counters
// decremented and the job's set deleted), so a re-delivered "queued" event
// does not double count. Each distinct, non-blank label then has its counter
// incremented and is stored in the job's set so markComplete can later undo
// exactly those increments.
//
// Labels are trimmed and de-duplicated before counting. The job's set stores
// each label once, so incrementing a counter once per duplicate occurrence
// would leave it permanently too high after the job is removed.
//
// NOTE(review): the read of the existing labels happens outside the
// transaction, so two concurrent deliveries for the same job can both observe
// the same prior state. Consider guarding the job key with WATCH if that
// matters in practice.
func (s *scalerServer) markQueued(ctx context.Context, jobID int64, labels []string) error {
	jobKey := fmt.Sprintf("%s%d", jobKeyPrefix, jobID)

	existingLabels, err := s.redis.SMembers(ctx, jobKey).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}

	// Normalize and de-duplicate the incoming labels, preserving first-seen order.
	seen := make(map[string]struct{}, len(labels))
	unique := make([]string, 0, len(labels))
	for _, label := range labels {
		normalized := strings.TrimSpace(label)
		if normalized == "" {
			continue
		}
		if _, dup := seen[normalized]; dup {
			continue
		}
		seen[normalized] = struct{}{}
		unique = append(unique, normalized)
	}

	pipe := s.redis.TxPipeline()

	// Undo whatever was recorded for this job before.
	if len(existingLabels) > 0 {
		for _, label := range existingLabels {
			pipe.Decr(ctx, labelKey(label))
		}
		pipe.Del(ctx, jobKey)
	}

	// Record the new state: one increment per distinct label.
	if len(unique) > 0 {
		members := make([]any, 0, len(unique))
		for _, label := range unique {
			pipe.Incr(ctx, labelKey(label))
			members = append(members, label)
		}
		pipe.SAdd(ctx, jobKey, members...)
	}

	_, err = pipe.Exec(ctx)
	return err
}
// markComplete removes jobID from the waiting bookkeeping: every label stored
// in the job's set has its counter decremented, and the set itself is deleted,
// all in one transactional pipeline. A job with no stored labels is a no-op.
func (s *scalerServer) markComplete(ctx context.Context, jobID int64) error {
	key := fmt.Sprintf("%s%d", jobKeyPrefix, jobID)

	stored, lookupErr := s.redis.SMembers(ctx, key).Result()
	if lookupErr != nil && !errors.Is(lookupErr, redis.Nil) {
		return lookupErr
	}

	// Nothing recorded for this job (never queued, or already removed).
	if len(stored) == 0 {
		return nil
	}

	tx := s.redis.TxPipeline()
	for i := range stored {
		tx.Decr(ctx, labelKey(stored[i]))
	}
	tx.Del(ctx, key)

	_, execErr := tx.Exec(ctx)
	return execErr
}
// countForLabels returns the sum of the waiting-job counters for the given
// labels, fetched with a single MGET. Missing counters contribute nothing and
// negative counters are treated as zero. An empty label list yields 0 without
// touching Redis. It returns an error if the lookup fails, a counter is not a
// valid integer, or a value has an unexpected type.
func (s *scalerServer) countForLabels(ctx context.Context, labels []string) (int, error) {
	if len(labels) == 0 {
		return 0, nil
	}

	keys := make([]string, len(labels))
	for i, label := range labels {
		keys[i] = labelKey(label)
	}

	raw, err := s.redis.MGet(ctx, keys...).Result()
	if err != nil {
		return 0, err
	}

	sum := 0
	for _, entry := range raw {
		// A nil entry means the key does not exist.
		if entry == nil {
			continue
		}

		var n int
		switch typed := entry.(type) {
		case string:
			parsed, convErr := strconv.Atoi(typed)
			if convErr != nil {
				return 0, convErr
			}
			n = parsed
		case int64:
			n = int(typed)
		default:
			return 0, fmt.Errorf("unexpected redis value type %T", entry)
		}

		// Clamp negatives so a drifted counter cannot cancel out real work.
		if n > 0 {
			sum += n
		}
	}
	return sum, nil
}
// verifySignature checks the webhook's HMAC-SHA256 signature over body using
// the WEBHOOK_SECRET environment variable as the key.
//
// When WEBHOOK_SECRET is empty, verification is skipped and nil is returned.
// Otherwise the hex signature is taken from X-Gitea-Signature or, if that is
// absent, from X-Hub-Signature-256 with its "sha256=" prefix removed. The
// comparison is case-insensitive on the hex digits and constant-time.
//
// It returns an error when no signature header is present or when the
// signature does not match.
func verifySignature(body []byte, r *http.Request) error {
	key := os.Getenv("WEBHOOK_SECRET")
	if key == "" {
		return nil
	}

	provided := r.Header.Get("X-Gitea-Signature")
	if provided == "" {
		provided = strings.TrimPrefix(r.Header.Get("X-Hub-Signature-256"), "sha256=")
	}
	if provided == "" {
		return errors.New("missing webhook signature header")
	}

	digest := hmac.New(sha256.New, []byte(key))
	digest.Write(body)
	want := make([]byte, hex.EncodedLen(sha256.Size))
	hex.Encode(want, digest.Sum(nil))

	if hmac.Equal(want, []byte(strings.ToLower(provided))) {
		return nil
	}
	return errors.New("invalid webhook signature")
}
// labelsFromRef extracts the label list from the "labels" entry of the scaled
// object's scaler metadata. The entry is split on commas; each piece is
// trimmed and blank pieces are dropped. A nil ref yields nil; a missing or
// blank entry yields an empty slice.
func labelsFromRef(ref *pb.ScaledObjectRef) []string {
	if ref == nil {
		return nil
	}

	pieces := strings.Split(ref.ScalerMetadata["labels"], ",")
	cleaned := make([]string, 0, len(pieces))
	for i := range pieces {
		trimmed := strings.TrimSpace(pieces[i])
		if trimmed == "" {
			continue
		}
		cleaned = append(cleaned, trimmed)
	}
	return cleaned
}
// normalizeAction maps a webhook payload to the scaler's internal verb:
// "queued" when the job is waiting, "remove" when it is in progress or
// completed, and "" when neither field is recognized. The top-level action is
// consulted first; the job status is used only if the action is unrecognized.
// Both fields are compared after trimming whitespace and lower-casing.
func normalizeAction(payload workflowJobPayload) string {
	// classify applies the same mapping to either field.
	classify := func(raw string) string {
		switch strings.ToLower(strings.TrimSpace(raw)) {
		case "queued":
			return "queued"
		case "in_progress", "completed":
			return "remove"
		}
		return ""
	}

	if verb := classify(payload.Action); verb != "" {
		return verb
	}
	return classify(payload.WorkflowJob.Status)
}
func labelKey(label string) string {
return fmt.Sprintf("%s%s", labelCountPrefix, label)
}
// getEnv returns the value of the environment variable key with surrounding
// whitespace removed. If the variable is unset or blank after trimming,
// fallback is returned instead.
func getEnv(key, fallback string) string {
	trimmed := strings.TrimSpace(os.Getenv(key))
	if trimmed == "" {
		return fallback
	}
	return trimmed
}
// mustAtoi parses value as a base-10 integer. It is meant for startup
// configuration: on a parse failure it logs the offending value and exits the
// process via log.Fatalf.
func mustAtoi(value string) int {
	parsed, convErr := strconv.Atoi(value)
	if convErr == nil {
		return parsed
	}
	log.Fatalf("invalid integer %q: %v", value, convErr)
	return parsed // not reached: log.Fatalf exits the process
}