タイトルのとおりの内容なので、さっそくサンプルコードを載せる。

追記: 以前の記事「aws-sdk-go-v2 Athenaにクエリを投げて結果をページングで受け取る」に、ほぼ同じ内容をすでに書いていた。違いはbackoffの有無くらい。

サンプルコード

import (
	"context"
	"fmt"
	"math"
	"math/rand"
	"time"
 
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/athena"
	"github.com/aws/aws-sdk-go-v2/service/athena/types"
)
 
const (
	// baseDelay is the shortest wait backoff produces between status polls.
	baseDelay time.Duration = 1 * time.Second

	// maxDelay is the upper bound on any single wait between status polls.
	maxDelay time.Duration = 30 * time.Second
)
 
// awsClient wraps the Athena API client that this type's methods use to start
// queries, poll their execution status, and page through their results.
type awsClient struct {
	athena *athena.Client
}
 
// NewAwsClient builds an awsClient whose Athena client is created from the
// default AWS configuration chain (config.LoadDefaultConfig), with the region
// overridden by region.
//
// It returns an error, wrapping the underlying cause, when the default
// configuration cannot be loaded.
func NewAwsClient(ctx context.Context, region string) (*awsClient, error) {
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
	if err != nil {
		return nil, fmt.Errorf("loading aws config for region %q: %w", region, err)
	}
	return &awsClient{athena: athena.NewFromConfig(cfg)}, nil
}
 
// rowData holds one query result row, mapping each column name to that cell's
// string value.
type rowData map[string]string
 
// Search starts query in Athena, with results written to the S3 location built
// from s3Bucket and s3Key. It then blocks until the query reaches a terminal
// state or timeout elapses, and returns the query execution ID. Callers can
// pass that ID to processResultRows to read the rows.
//
// Errors from starting or waiting on the query are wrapped with context.
// The underlying error stays reachable via errors.Is / errors.As.
func (ac *awsClient) Search(ctx context.Context, s3Bucket, s3Key string, query string, timeout time.Duration) (*string, error) {
	queryExecutionID, err := ac.startQuery(ctx, s3Bucket, s3Key, query)
	if err != nil {
		return nil, fmt.Errorf("start athena query: %w", err)
	}

	if err := ac.waitForQueryToComplete(ctx, queryExecutionID, timeout); err != nil {
		// Include the ID so the failed execution can be looked up in the console.
		return nil, fmt.Errorf("athena query %s: %w", aws.ToString(queryExecutionID), err)
	}

	return queryExecutionID, nil
}
 
// startQuery submits query to Athena via StartQueryExecution and returns the
// new query execution ID. It does not wait for the query to run.
//
// Results are written to s3://<s3Bucket>/<s3Key>/, or to s3://<s3Bucket>/ when
// s3Key is empty.
func (ac *awsClient) startQuery(ctx context.Context, s3Bucket, s3Key string, query string) (*string, error) {
	// Without this check an empty key would produce "s3://bucket//", which adds
	// an empty path segment to every result object key.
	outputLocation := fmt.Sprintf("s3://%s/", s3Bucket)
	if s3Key != "" {
		outputLocation = fmt.Sprintf("s3://%s/%s/", s3Bucket, s3Key)
	}

	input := &athena.StartQueryExecutionInput{
		QueryString: aws.String(query),
		ResultConfiguration: &types.ResultConfiguration{
			OutputLocation: aws.String(outputLocation),
		},
	}

	output, err := ac.athena.StartQueryExecution(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("athena StartQueryExecution with output location %q: %w", outputLocation, err)
	}

	return output.QueryExecutionId, nil
}
 
// waitForQueryToComplete polls GetQueryExecution for queryExecutionId until the
// query reaches a terminal state, timeout elapses, or ctx is done.
//
// Between polls it waits backoff(attempt, baseDelay, maxDelay). The wait is
// shortened so it never runs past the deadline. It returns:
//   - nil when the query SUCCEEDED;
//   - an error carrying Athena's StateChangeReason when the query FAILED;
//   - an error when the query was CANCELLED;
//   - the GetQueryExecution error, or ctx.Err() if ctx is done while waiting;
//   - on timeout, after calling StopQueryExecution, "athena query was timeout",
//     or a wrapped error if stopping the query failed.
//
// NOTE: if ctx is cancelled, the query is left running in Athena. Only the
// timeout path asks Athena to stop it.
func (ac *awsClient) waitForQueryToComplete(ctx context.Context, queryExecutionId *string, timeout time.Duration) error {
	input := &athena.GetQueryExecutionInput{
		QueryExecutionId: queryExecutionId,
	}

	deadline := time.Now().Add(timeout)
	for attempt := 0; ; attempt++ {
		output, err := ac.athena.GetQueryExecution(ctx, input)
		if err != nil {
			return err
		}
		if output.QueryExecution == nil || output.QueryExecution.Status == nil {
			return fmt.Errorf("athena returned no status for query %s", aws.ToString(queryExecutionId))
		}

		status := output.QueryExecution.Status
		switch status.State {
		case types.QueryExecutionStateSucceeded:
			return nil
		case types.QueryExecutionStateFailed:
			// StateChangeReason is an optional pointer; aws.ToString avoids a nil dereference.
			return fmt.Errorf("athena query failed to run with error message: %s", aws.ToString(status.StateChangeReason))
		case types.QueryExecutionStateCancelled:
			return fmt.Errorf("athena query was cancelled")
		}

		// Any other state (e.g. QUEUED or RUNNING): give up if out of time, else wait.
		remaining := time.Until(deadline)
		if remaining <= 0 {
			_, stopErr := ac.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{
				QueryExecutionId: queryExecutionId,
			})
			if stopErr != nil {
				return fmt.Errorf("cannot stop athena query: %w", stopErr)
			}
			return fmt.Errorf("athena query was timeout")
		}

		delay := backoff(attempt, baseDelay, maxDelay)
		if delay > remaining {
			// Don't oversleep the deadline by up to maxDelay.
			delay = remaining
		}

		// Unlike time.Sleep, this wakes up as soon as ctx is cancelled.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}
 
// processResultRows pages through the results of the query identified by
// queryExecutionId using GetQueryResults. It converts each data row into a
// rowData keyed by column name.
//
// Column names are taken from the ResultSetMetadata of the first page that
// carries it. The very first result row is skipped as a header row.
//
// NULL cells (VarCharValue == nil) are omitted from the row's map. A lookup
// rd[col] therefore yields "", while `_, ok := rd[col]` reports false. A row
// with more cells than known columns is reported as an error.
//
// NOTE(review): the header skip assumes the result begins with a header row, as
// Athena emits for SELECT queries. Confirm callers never pass the ID of a
// statement whose result lacks one.
func (ac *awsClient) processResultRows(ctx context.Context, queryExecutionId string) ([]rowData, error) {
	input := &athena.GetQueryResultsInput{
		QueryExecutionId: aws.String(queryExecutionId),
	}

	rds := make([]rowData, 0)
	var columns []string
	// Track column loading and header skipping separately. Otherwise pages with
	// no rows would cause the column names to be appended more than once.
	haveColumns := false
	headerSkipped := false

	paginator := athena.NewGetQueryResultsPaginator(ac.athena, input)
	for paginator.HasMorePages() {
		output, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, err
		}
		if output.ResultSet == nil {
			continue
		}

		if !haveColumns && output.ResultSet.ResultSetMetadata != nil {
			columnInfo := output.ResultSet.ResultSetMetadata.ColumnInfo
			columns = make([]string, 0, len(columnInfo))
			for _, meta := range columnInfo {
				columns = append(columns, aws.ToString(meta.Name))
			}
			haveColumns = true
		}

		for _, row := range output.ResultSet.Rows {
			if !headerSkipped {
				// The first row of the whole result holds the column names, not data.
				headerSkipped = true
				continue
			}
			if len(row.Data) > len(columns) {
				return nil, fmt.Errorf("athena result row has %d cells but only %d columns", len(row.Data), len(columns))
			}
			rd := make(rowData, len(row.Data))
			for i, cell := range row.Data {
				if cell.VarCharValue == nil {
					continue // SQL NULL
				}
				rd[columns[i]] = *cell.VarCharValue
			}
			rds = append(rds, rd)
		}
	}

	return rds, nil
}
 
// backoff returns how long to wait before retry number attempt (0-based). It
// uses exponential backoff with jitter.
//
// The ceiling grows as baseDelay * 2^attempt and is clamped to maxDelay. The
// result is then drawn uniformly from [baseDelay, ceiling). Because the clamp
// happens before the jitter, concurrent pollers stay spread out even after
// they reach the cap.
//
// When the ceiling is not above baseDelay, the ceiling itself is returned.
// This happens for attempt 0, or when maxDelay <= baseDelay. Negative attempts
// are treated as 0.
func backoff(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
	if attempt < 0 {
		attempt = 0
	}
	basef := float64(baseDelay)
	maxf := float64(maxDelay)

	ceiling := basef * math.Pow(2, float64(attempt))
	// Large attempts overflow to +Inf, or to NaN when baseDelay is 0. Clamp
	// before jittering so the Duration conversion below stays well defined.
	if math.IsNaN(ceiling) || ceiling > maxf {
		ceiling = maxf
	}
	if ceiling <= basef {
		return time.Duration(ceiling)
	}
	return time.Duration(basef + rand.Float64()*(ceiling-basef))
}

ポイント

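  • クエリを開始する（StartQueryExecution。結果は s3://<bucket>/<key>/ に出力される）
  • クエリの完了を待つ（GetQueryExecution でステータスをポーリングする。間隔はジッター付きの指数バックオフで伸ばし、タイムアウトしたら StopQueryExecution でクエリを止める）
  • 結果を取得する
    • athena.NewGetQueryResultsPaginator（GetQueryResultsPaginator）を使うとページネーションを簡単に実装できる。最初の行はヘッダ行（カラム名）なので読み飛ばす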