import (
	"context"
	"fmt"
	"time"
 
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/athena"
	"github.com/aws/aws-sdk-go-v2/service/athena/types"
)
 
// awsClient wraps an Athena client from the AWS SDK for Go v2 and exposes
// higher-level helpers (such as Query) on top of it. Construct it with
// NewAwsClient rather than as a zero value: the zero value has a nil
// athenaClient and any call through it will panic.
type awsClient struct {
	// athenaClient is the SDK client every Athena API call goes through.
	athenaClient *athena.Client
}
 
func NewAwsClient(ctx context.Context) (*awsClient, error) {
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("ap-northeast-1"))
	if err != nil {
		return nil, err
	}
	athenaClient := athena.NewFromConfig(cfg)
	return &awsClient{
		athenaClient: athenaClient,
	}, nil
}
 
// rowData is a single Athena result row keyed by column name. Each value is
// the cell's VarCharValue string as returned by GetQueryResults (Athena
// delivers every column type as a string here).
type rowData map[string]string
 
// Query starts query on Athena, blocks until it reaches a terminal state, and
// returns every result row as a rowData keyed by column name.
//
// outputLocation is the S3 location Athena writes the result files to (the
// ResultConfiguration.OutputLocation of StartQueryExecution). When it is
// empty, no ResultConfiguration is sent and Athena uses the output location
// configured on the query's workgroup instead.
//
// Errors from starting, waiting for, or reading the query are wrapped with
// %w, so errors.Is / errors.As still see the underlying SDK or context error.
func (ac *awsClient) Query(ctx context.Context, query string, outputLocation string) ([]rowData, error) {
	input := &athena.StartQueryExecutionInput{
		QueryString: aws.String(query),
	}
	if outputLocation != "" {
		input.ResultConfiguration = &types.ResultConfiguration{
			OutputLocation: aws.String(outputLocation),
		}
	}

	output, err := ac.athenaClient.StartQueryExecution(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("starting athena query: %w", err)
	}

	// Guard against a nil ID instead of dereferencing it blindly.
	queryExecutionID := aws.ToString(output.QueryExecutionId)
	if queryExecutionID == "" {
		return nil, fmt.Errorf("starting athena query: no query execution id returned")
	}

	if err := ac.waitForQueryToComplete(ctx, queryExecutionID); err != nil {
		return nil, fmt.Errorf("waiting for athena query %s: %w", queryExecutionID, err)
	}

	// Named "rows" rather than "rowData" so the package-level type is not shadowed.
	rows, err := ac.processResultRows(ctx, queryExecutionID)
	if err != nil {
		return nil, fmt.Errorf("reading results of athena query %s: %w", queryExecutionID, err)
	}
	return rows, nil
}
// waitForQueryToComplete polls GetQueryExecution for queryExecutionID until
// the query reaches a terminal state.
//
// It returns:
//   - nil when the query SUCCEEDED;
//   - an error carrying Athena's StateChangeReason when it FAILED;
//   - an error when it was CANCELLED;
//   - ctx.Err() as soon as ctx is done (no sleeping through cancellation);
//   - a timeout error when the query is still running after maxWaits waits
//     of pollInterval each.
//
// The state is polled once more after the last wait, so the timeout is only
// reported after the query's state at the deadline has actually been observed.
func (ac *awsClient) waitForQueryToComplete(ctx context.Context, queryExecutionID string) error {
	const (
		maxWaits     = 10
		pollInterval = 10 * time.Second
	)

	input := &athena.GetQueryExecutionInput{
		QueryExecutionId: aws.String(queryExecutionID),
	}

	for waits := 0; ; waits++ {
		output, err := ac.athenaClient.GetQueryExecution(ctx, input)
		if err != nil {
			return fmt.Errorf("getting status of athena query %s: %w", queryExecutionID, err)
		}

		// A missing QueryExecution/Status is treated as "not finished yet"
		// rather than dereferenced.
		if qe := output.QueryExecution; qe != nil && qe.Status != nil {
			switch qe.Status.State {
			case types.QueryExecutionStateSucceeded:
				return nil
			case types.QueryExecutionStateFailed:
				// StateChangeReason is optional; aws.ToString maps nil to "".
				return fmt.Errorf("athena query failed to run with error message: %s", aws.ToString(qe.Status.StateChangeReason))
			case types.QueryExecutionStateCancelled:
				return fmt.Errorf("athena query was cancelled")
			}
		}

		if waits >= maxWaits {
			return fmt.Errorf("athena query was timeout")
		}

		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}
 
// processResultRows pages through GetQueryResults for the completed query
// queryExecutionID and converts each data row into a rowData keyed by the
// column names taken from the result-set metadata.
//
// Athena returns the column header as the first row of the first page; that
// row is skipped. SQL NULL cells (nil VarCharValue) are left out of the row
// map, so `v, ok := row[col]` distinguishes NULL (ok == false) from an empty
// string. The returned slice is non-nil (possibly empty) on success.
func (ac *awsClient) processResultRows(ctx context.Context, queryExecutionID string) ([]rowData, error) {
	input := &athena.GetQueryResultsInput{
		QueryExecutionId: aws.String(queryExecutionID),
	}

	results := make([]rowData, 0)
	var columns []string
	columnsLoaded := false // column names are read from the first page that carries metadata
	headerPending := true  // the very first row returned is the header row

	paginator := athena.NewGetQueryResultsPaginator(ac.athenaClient, input)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("fetching result page: %w", err)
		}
		if page.ResultSet == nil {
			continue
		}

		if !columnsLoaded && page.ResultSet.ResultSetMetadata != nil {
			for _, info := range page.ResultSet.ResultSetMetadata.ColumnInfo {
				columns = append(columns, aws.ToString(info.Name))
			}
			columnsLoaded = true
		}

		for _, row := range page.ResultSet.Rows {
			if headerPending {
				headerPending = false
				continue
			}
			if len(row.Data) > len(columns) {
				return nil, fmt.Errorf("result row has %d cells but only %d columns are known", len(row.Data), len(columns))
			}
			rd := make(rowData, len(columns))
			for i, cell := range row.Data {
				if cell.VarCharValue == nil {
					continue // SQL NULL
				}
				rd[columns[i]] = *cell.VarCharValue
			}
			results = append(results, rd)
		}
	}

	return results, nil
}