import ( "context" "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/athena" "github.com/aws/aws-sdk-go-v2/service/athena/types" ) type awsClient struct { athenaClient *athena.Client } func NewAwsClient(ctx context.Context) (*awsClient, error) { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("ap-northeast-1")) if err != nil { return nil, err } athenaClient := athena.NewFromConfig(cfg) return &awsClient{ athenaClient: athenaClient, }, nil } type rowData map[string]string func (ac *awsClient) Query(ctx context.Context, query string, outputLocation string) ([]rowData, error) { input := &athena.StartQueryExecutionInput{ QueryString: aws.String(query), ResultConfiguration: &types.ResultConfiguration{ OutputLocation: aws.String(outputLocation), }, } output, err := ac.athenaClient.StartQueryExecution(ctx, input) if err != nil { return nil, err } queryExecutionId := *output.QueryExecutionId err = ac.waitForQueryToComplete(ctx, queryExecutionId) if err != nil { return nil, err } rowData, err := ac.processResultRows(ctx, queryExecutionId) if err != nil { return nil, err } return rowData, nil } func (ac *awsClient) waitForQueryToComplete(ctx context.Context, queryExecutionId string) error { input := &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(queryExecutionId), } runCount := 0 maxRunCount := 10 sleepSec := 10 for runCount < maxRunCount { output, err := ac.athenaClient.GetQueryExecution(ctx, input) if err != nil { return err } switch output.QueryExecution.Status.State { case types.QueryExecutionStateSucceeded: return nil case types.QueryExecutionStateFailed: return fmt.Errorf("athena query failed to run with error message: %s", *output.QueryExecution.Status.StateChangeReason) case types.QueryExecutionStateCancelled: return fmt.Errorf("athena query was cancelled") default: time.Sleep(time.Duration(sleepSec) * time.Second) runCount++ } } return fmt.Errorf("athena query was timeout") } func (ac *awsClient) processResultRows(ctx context.Context, queryExecutionId string) ([]rowData, error) { input := &athena.GetQueryResultsInput{ QueryExecutionId: aws.String(queryExecutionId), } rds := make([]rowData, 0) rns := make([]string, 0) paginator := athena.NewGetQueryResultsPaginator(ac.athenaClient, input) first := true for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { return nil, err } if first { for _, meta := range output.ResultSet.ResultSetMetadata.ColumnInfo { rns = append(rns, *meta.Name) } } for _, v := range output.ResultSet.Rows { if first { // Ignore first row of first run. It's header row first = false continue } rd := rowData{} for i, d := range v.Data { rd[rns[i]] = *d.VarCharValue } rds = append(rds, rd) } } return rds, nil }