Files
runbin/internal/repository/database_store.go
2025-04-14 22:57:36 +08:00

175 lines
3.8 KiB
Go

package repository
import (
"context"
"database/sql"
"fmt"
"runbin/internal/model"
"time"
_ "github.com/lib/pq"
)
type PostgresStore struct {
db *sql.DB
}
// NewPostgresStore opens a connection pool to PostgreSQL using connStr and
// verifies it with a ping bounded by a 5-second timeout.
//
// connStr is handed unchanged to the "postgres" driver registered by the
// github.com/lib/pq import. If the ping fails, the pool is closed before the
// error is returned, so a failed construction does not leak connections.
func NewPostgresStore(connStr string) (*PostgresStore, error) {
	// sql.Open only validates its arguments and prepares the pool; the first
	// real network round-trip happens in PingContext below.
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Release the pool we just opened; the ping error is the one worth
		// reporting, so the Close error is intentionally discarded.
		_ = db.Close()
		return nil, fmt.Errorf("database ping failed: %w", err)
	}
	return &PostgresStore{db: db}, nil
}
// Save inserts p as a new row in the pastes table, writing all thirteen
// columns (id through compile_log) from the corresponding fields of p.
//
// The insert is bounded by a 3-second timeout derived from
// context.Background(). A nil p is rejected with an error instead of
// panicking. Driver errors are wrapped with %w, so errors.Is/As still see
// the underlying cause.
func (s *PostgresStore) Save(p *model.Paste) error {
	if p == nil {
		return fmt.Errorf("cannot save a nil paste")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	_, err := s.db.ExecContext(ctx,
		`INSERT INTO pastes (
id, code, created_at, status,
language, stdin, stdout, stderr,
execution_time_ms, memory_usage_kb, updated_at, backend,
compile_log
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)`,
		p.ID, p.Code, p.CreatedAt, p.Status,
		p.Language, p.Stdin, p.Stdout, p.Stderr,
		p.ExecutionTimeMs, p.MemoryUsageKb, p.UpdatedAt, p.BackEnd, p.CompileLog)
	if err != nil {
		// Same wrapping style as Update, so callers get the paste id in the message.
		return fmt.Errorf("failed to insert paste with id %s: %w", p.ID, err)
	}
	return nil
}
// GetByID loads the paste whose id column equals id.
//
// It returns (paste, true) when a matching row is found and scanned. It
// returns (nil, false) in every other case. That covers no matching row
// (sql.ErrNoRows) and also genuine failures such as a timeout, a lost
// connection, or a scan/type mismatch. The boolean result cannot carry the
// cause, so callers cannot distinguish "not found" from "database error".
//
// The query is bounded by a 3-second timeout derived from
// context.Background(), independent of any caller context.
func (s *PostgresStore) GetByID(id string) (*model.Paste, bool) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	var p model.Paste
	err := s.db.QueryRowContext(ctx,
		`SELECT
id, code, created_at, status,
language, stdin, stdout, stderr,
execution_time_ms, memory_usage_kb, updated_at, backend,
compile_log
FROM pastes WHERE id = $1`, id).Scan(
		&p.ID,
		&p.Code,
		&p.CreatedAt,
		&p.Status,
		&p.Language,
		&p.Stdin,
		&p.Stdout,
		&p.Stderr,
		&p.ExecutionTimeMs,
		&p.MemoryUsageKb,
		&p.UpdatedAt,
		&p.BackEnd,
		&p.CompileLog)
	if err != nil {
		// sql.ErrNoRows and real database errors are deliberately folded
		// together: the (value, ok) signature has no slot for the cause.
		// NOTE(review): consider a context-aware variant returning error.
		return nil, false
	}
	return &p, true
}
// Close closes the underlying *sql.DB pool and returns its error. Per the
// database/sql contract, this prevents new queries from starting through
// this store.
func (s *PostgresStore) Close() error { return s.db.Close() }
// DispatchExecutionTask enqueues the paste identified by id for execution by
// inserting a row into the queue table. Only the id column is written. Any
// other queue columns (GetTask orders by created_at) presumably come from
// table defaults; verify against the schema.
//
// The insert is bounded by a 3-second timeout derived from
// context.Background(). Driver errors are wrapped with %w and include the id.
func (s *PostgresStore) DispatchExecutionTask(id string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	if _, err := s.db.ExecContext(ctx,
		`INSERT INTO queue (id) VALUES ($1)`,
		id); err != nil {
		return fmt.Errorf("failed to enqueue execution task for paste with id %s: %w", id, err)
	}
	return nil
}
// GetTask claims the oldest pending task from the queue and returns the
// paste it refers to. It returns (nil, nil) when the queue is empty.
//
// The queue row is removed with DELETE ... FOR UPDATE SKIP LOCKED, so
// concurrent workers never claim the same row. The paste is read inside the
// same transaction as the delete, and all work is bounded by ctx:
//   - If reading the paste fails with a database error, the transaction is
//     rolled back and the task stays queued for a later attempt.
//   - If the queue row references a paste that does not exist, the row is
//     still removed (otherwise it would be re-claimed forever) and an error
//     naming the task id is returned.
func (s *PostgresStore) GetTask(ctx context.Context) (*model.Paste, error) {
	// READ COMMITTED is the level SKIP LOCKED queues are designed for. Under
	// SERIALIZABLE, a worker that tries to lock a row another worker already
	// deleted fails with a serialization error instead of skipping it.
	tx, err := s.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit this Rollback is a harmless no-op.
	defer tx.Rollback()

	// Atomically delete and return the oldest unlocked task id.
	var taskID string
	err = tx.QueryRowContext(ctx,
		`DELETE FROM queue
WHERE ctid = (
SELECT ctid FROM queue
ORDER BY created_at
FOR UPDATE SKIP LOCKED
LIMIT 1
)
RETURNING id`).Scan(&taskID)
	if err == sql.ErrNoRows {
		return nil, nil // queue is empty
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get task: %w", err)
	}

	p, err := selectPasteTx(ctx, tx, taskID)
	if err == sql.ErrNoRows {
		// Orphaned queue entry: commit the delete so it is not retried forever.
		if commitErr := tx.Commit(); commitErr != nil {
			return nil, fmt.Errorf("failed to commit transaction: %w", commitErr)
		}
		return nil, fmt.Errorf("failed to get task details: paste %s not found", taskID)
	}
	if err != nil {
		// The deferred Rollback restores the queue row.
		return nil, fmt.Errorf("failed to get task details for paste %s: %w", taskID, err)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("failed to commit transaction: %w", err)
	}
	return p, nil
}

// selectPasteTx reads the paste with the given id through tx. When no row
// matches, it returns sql.ErrNoRows unwrapped so callers can compare against it.
func selectPasteTx(ctx context.Context, tx *sql.Tx, id string) (*model.Paste, error) {
	var p model.Paste
	err := tx.QueryRowContext(ctx,
		`SELECT
id, code, created_at, status,
language, stdin, stdout, stderr,
execution_time_ms, memory_usage_kb, updated_at, backend,
compile_log
FROM pastes WHERE id = $1`, id).Scan(
		&p.ID,
		&p.Code,
		&p.CreatedAt,
		&p.Status,
		&p.Language,
		&p.Stdin,
		&p.Stdout,
		&p.Stderr,
		&p.ExecutionTimeMs,
		&p.MemoryUsageKb,
		&p.UpdatedAt,
		&p.BackEnd,
		&p.CompileLog)
	if err != nil {
		return nil, err
	}
	return &p, nil
}
// Update writes the execution-result columns of p back to its row in pastes.
// The columns written are status, stdout, stderr, execution_time_ms,
// memory_usage_kb, updated_at, backend and compile_log, matched by id.
//
// p.UpdatedAt is overwritten with time.Now() before the write, so the
// caller's struct reflects the stored timestamp. The statement is bounded by
// a 3-second timeout. Failures are wrapped with %w and mention the paste id.
func (s *PostgresStore) Update(p *model.Paste) error {
	// Every update refreshes the modification timestamp.
	p.UpdatedAt = time.Now()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	const stmt = `UPDATE pastes SET
status = $1,
stdout = $2,
stderr = $3,
execution_time_ms = $4,
memory_usage_kb = $5,
updated_at = $6,
backend = $7,
compile_log = $8
WHERE id = $9; `

	if _, execErr := s.db.ExecContext(ctx, stmt,
		p.Status, p.Stdout, p.Stderr,
		p.ExecutionTimeMs, p.MemoryUsageKb, p.UpdatedAt,
		p.BackEnd, p.CompileLog, p.ID,
	); execErr != nil {
		return fmt.Errorf("failed to execute update for paste with id %s: %w", p.ID, execErr)
	}
	return nil
}