// Command gitea-scaler bridges Gitea Actions workflow_job webhooks to a
// KEDA external scaler, tracking queued jobs per runner label in Redis.
package main
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"
	"google.golang.org/grpc"

	pb "gitea-scaler/externalscaler"
)
const (
	// metricName is the external metric name reported to KEDA.
	metricName = "gitea_waiting_jobs"
	// jobKeyPrefix prefixes the per-job Redis set holding that job's labels,
	// keyed by workflow job ID.
	jobKeyPrefix = "gitea-scaler:job:"
	// labelCountPrefix prefixes the per-label Redis counter of waiting jobs.
	labelCountPrefix = "gitea-scaler:label:"
)
// workflowJobPayload is the subset of the workflow_job webhook body this
// scaler reads (Gitea and GitHub use the same field names).
type workflowJobPayload struct {
	// Action is the webhook action, e.g. "queued", "in_progress", "completed".
	Action string `json:"action"`
	// WorkflowJob carries the job's identity, current status, and runner labels.
	WorkflowJob struct {
		ID     int64    `json:"id"`
		Status string   `json:"status"`
		Labels []string `json:"labels"`
	} `json:"workflow_job"`
}
// scalerServer implements the KEDA ExternalScaler gRPC service and the
// webhook HTTP handler, backed by a shared Redis client.
type scalerServer struct {
	pb.UnimplementedExternalScalerServer
	// redis stores per-job label sets and per-label waiting-job counters.
	redis *redis.Client
}
func main() {
|
|
redisClient := redis.NewClient(&redis.Options{
|
|
Addr: getEnv("REDIS_ADDR", "redis.jam-cloud-infra.svc.cluster.local:6379"),
|
|
DB: mustAtoi(getEnv("REDIS_DB", "1")),
|
|
})
|
|
|
|
if err := redisClient.Ping(context.Background()).Err(); err != nil {
|
|
log.Fatalf("failed to connect to redis: %v", err)
|
|
}
|
|
|
|
server := &scalerServer{redis: redisClient}
|
|
|
|
go serveGRPC(server)
|
|
go serveHTTP(server)
|
|
|
|
select {}
|
|
}
|
|
|
|
func serveGRPC(server *scalerServer) {
|
|
grpcServer := grpc.NewServer()
|
|
pb.RegisterExternalScalerServer(grpcServer, server)
|
|
|
|
lis, err := net.Listen("tcp", getEnv("GRPC_ADDR", ":50051"))
|
|
if err != nil {
|
|
log.Fatalf("failed to listen for grpc: %v", err)
|
|
}
|
|
|
|
log.Printf("gRPC scaler listening on %s", lis.Addr())
|
|
if err := grpcServer.Serve(lis); err != nil {
|
|
log.Fatalf("failed to serve grpc: %v", err)
|
|
}
|
|
}
|
|
|
|
func serveHTTP(server *scalerServer) {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
})
|
|
mux.HandleFunc("/webhooks/workflow-job", server.handleWorkflowJob)
|
|
|
|
addr := getEnv("HTTP_ADDR", ":8080")
|
|
log.Printf("webhook server listening on %s", addr)
|
|
if err := http.ListenAndServe(addr, mux); err != nil {
|
|
log.Fatalf("failed to serve http: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *scalerServer) handleWorkflowJob(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
event := r.Header.Get("X-Gitea-Event")
|
|
if event == "" {
|
|
event = r.Header.Get("X-GitHub-Event")
|
|
}
|
|
if event != "workflow_job" {
|
|
http.Error(w, "unsupported event", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, "failed to read request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if err := verifySignature(body, r); err != nil {
|
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
var payload workflowJobPayload
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
http.Error(w, "invalid json payload", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if payload.WorkflowJob.ID == 0 {
|
|
http.Error(w, "missing workflow_job.id", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
switch normalizeAction(payload) {
|
|
case "queued":
|
|
if err := s.markQueued(r.Context(), payload.WorkflowJob.ID, payload.WorkflowJob.Labels); err != nil {
|
|
http.Error(w, "failed to record queued job", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
case "remove":
|
|
if err := s.markComplete(r.Context(), payload.WorkflowJob.ID); err != nil {
|
|
http.Error(w, "failed to remove completed job", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
default:
|
|
log.Printf("ignoring workflow_job action=%q status=%q id=%d", payload.Action, payload.WorkflowJob.Status, payload.WorkflowJob.ID)
|
|
}
|
|
|
|
w.WriteHeader(http.StatusAccepted)
|
|
}
|
|
|
|
func (s *scalerServer) IsActive(ctx context.Context, ref *pb.ScaledObjectRef) (*pb.IsActiveResponse, error) {
|
|
count, err := s.countForLabels(ctx, labelsFromRef(ref))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &pb.IsActiveResponse{Result: count > 0}, nil
|
|
}
|
|
|
|
func (s *scalerServer) StreamIsActive(ref *pb.ScaledObjectRef, stream pb.ExternalScaler_StreamIsActiveServer) error {
|
|
ticker := time.NewTicker(15 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
count, err := s.countForLabels(stream.Context(), labelsFromRef(ref))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := stream.Send(&pb.IsActiveResponse{Result: count > 0}); err != nil {
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case <-stream.Context().Done():
|
|
return nil
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *scalerServer) GetMetricSpec(context.Context, *pb.ScaledObjectRef) (*pb.GetMetricSpecResponse, error) {
|
|
return &pb.GetMetricSpecResponse{
|
|
MetricSpecs: []*pb.MetricSpec{
|
|
{
|
|
MetricName: metricName,
|
|
TargetSize: 1,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (s *scalerServer) GetMetrics(ctx context.Context, req *pb.GetMetricsRequest) (*pb.GetMetricsResponse, error) {
|
|
count, err := s.countForLabels(ctx, labelsFromRef(req.GetScaledObjectRef()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &pb.GetMetricsResponse{
|
|
MetricValues: []*pb.MetricValue{
|
|
{
|
|
MetricName: metricName,
|
|
MetricValue: int64(count),
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// markQueued records a queued job in Redis: the job's labels are stored in
// a per-job set, and each label's waiting counter is incremented. If the
// job was already recorded (duplicate/replayed webhook), its previous
// contribution is reversed first so counters are never double-counted.
func (s *scalerServer) markQueued(ctx context.Context, jobID int64, labels []string) error {
	jobKey := fmt.Sprintf("%s%d", jobKeyPrefix, jobID)

	// Labels recorded by an earlier delivery of this same job, if any.
	existingLabels, err := s.redis.SMembers(ctx, jobKey).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}

	pipe := s.redis.TxPipeline()

	// Undo the previous recording before re-adding, so a re-delivered
	// webhook leaves each label counter incremented exactly once. Order
	// matters: Del must precede the SAdd queued below.
	if len(existingLabels) > 0 {
		for _, label := range existingLabels {
			pipe.Decr(ctx, labelKey(label))
		}
		pipe.Del(ctx, jobKey)
	}

	if len(labels) > 0 {
		members := make([]interface{}, 0, len(labels))
		for _, label := range labels {
			normalized := strings.TrimSpace(label)
			if normalized == "" {
				// Skip blank labels so they never become Redis keys.
				continue
			}
			members = append(members, normalized)
			pipe.Incr(ctx, labelKey(normalized))
		}
		if len(members) > 0 {
			pipe.SAdd(ctx, jobKey, members...)
		}
	}

	// All queued commands execute atomically (MULTI/EXEC).
	_, err = pipe.Exec(ctx)
	return err
}
func (s *scalerServer) markComplete(ctx context.Context, jobID int64) error {
|
|
jobKey := fmt.Sprintf("%s%d", jobKeyPrefix, jobID)
|
|
labels, err := s.redis.SMembers(ctx, jobKey).Result()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
return err
|
|
}
|
|
|
|
if len(labels) == 0 {
|
|
return nil
|
|
}
|
|
|
|
pipe := s.redis.TxPipeline()
|
|
for _, label := range labels {
|
|
pipe.Decr(ctx, labelKey(label))
|
|
}
|
|
pipe.Del(ctx, jobKey)
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (s *scalerServer) countForLabels(ctx context.Context, labels []string) (int, error) {
|
|
if len(labels) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
keys := make([]string, 0, len(labels))
|
|
for _, label := range labels {
|
|
keys = append(keys, labelKey(label))
|
|
}
|
|
|
|
values, err := s.redis.MGet(ctx, keys...).Result()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
total := 0
|
|
for _, value := range values {
|
|
switch v := value.(type) {
|
|
case nil:
|
|
case string:
|
|
n, err := strconv.Atoi(v)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
total += max(n, 0)
|
|
case int64:
|
|
total += max(int(v), 0)
|
|
default:
|
|
return 0, fmt.Errorf("unexpected redis value type %T", value)
|
|
}
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
func verifySignature(body []byte, r *http.Request) error {
|
|
secret := os.Getenv("WEBHOOK_SECRET")
|
|
if secret == "" {
|
|
return nil
|
|
}
|
|
|
|
signature := r.Header.Get("X-Gitea-Signature")
|
|
if signature == "" {
|
|
signature = r.Header.Get("X-Hub-Signature-256")
|
|
signature = strings.TrimPrefix(signature, "sha256=")
|
|
}
|
|
if signature == "" {
|
|
return errors.New("missing webhook signature header")
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, []byte(secret))
|
|
mac.Write(body)
|
|
expected := hex.EncodeToString(mac.Sum(nil))
|
|
|
|
if !hmac.Equal([]byte(expected), []byte(strings.ToLower(signature))) {
|
|
return errors.New("invalid webhook signature")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func labelsFromRef(ref *pb.ScaledObjectRef) []string {
|
|
if ref == nil {
|
|
return nil
|
|
}
|
|
labels := strings.Split(ref.ScalerMetadata["labels"], ",")
|
|
result := make([]string, 0, len(labels))
|
|
for _, label := range labels {
|
|
normalized := strings.TrimSpace(label)
|
|
if normalized != "" {
|
|
result = append(result, normalized)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func normalizeAction(payload workflowJobPayload) string {
|
|
action := strings.ToLower(strings.TrimSpace(payload.Action))
|
|
status := strings.ToLower(strings.TrimSpace(payload.WorkflowJob.Status))
|
|
|
|
switch action {
|
|
case "queued":
|
|
return "queued"
|
|
case "in_progress", "completed":
|
|
return "remove"
|
|
}
|
|
|
|
switch status {
|
|
case "queued":
|
|
return "queued"
|
|
case "in_progress", "completed":
|
|
return "remove"
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func labelKey(label string) string {
|
|
return fmt.Sprintf("%s%s", labelCountPrefix, label)
|
|
}
|
|
|
|
func getEnv(key, fallback string) string {
|
|
if value := strings.TrimSpace(os.Getenv(key)); value != "" {
|
|
return value
|
|
}
|
|
return fallback
|
|
}
|
|
|
|
// mustAtoi parses value as a base-10 integer, terminating the process on
// failure (intended for configuration read once at startup).
func mustAtoi(value string) int {
	n, err := strconv.Atoi(value)
	if err == nil {
		return n
	}
	log.Fatalf("invalid integer %q: %v", value, err)
	return 0 // unreachable: log.Fatalf exits the process
}