package msc2836_test

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"io/ioutil"
	"net/http"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/mux"
	"github.com/matrix-org/dendrite/internal/hooks"
	"github.com/matrix-org/dendrite/internal/httputil"
	roomserver "github.com/matrix-org/dendrite/roomserver/api"
	"github.com/matrix-org/dendrite/setup/base"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/setup/mscs/msc2836"
	userapi "github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/gomatrixserverlib"
)

var (
	client = &http.Client{
		Timeout: 10 * time.Second,
	}
)

// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
//   A
//   |
//   B
//  / \
// C   D
//    /|\
//   E F G
//   |
//   H
// And makes sure POST /event_relationships works with various parameters
func TestMSC2836(t *testing.T) {
	alice := "@alice:localhost"
	bob := "@bob:localhost"
	charlie := "@charlie:localhost"
	roomID := "!alice:localhost"
	// give access tokens to all three users
	nopUserAPI := &testUserAPI{
		accessTokens: make(map[string]userapi.Device),
	}
	nopUserAPI.accessTokens["alice"] = userapi.Device{
		AccessToken: "alice",
		DisplayName: "Alice",
		UserID:      alice,
	}
	nopUserAPI.accessTokens["bob"] = userapi.Device{
		AccessToken: "bob",
		DisplayName: "Bob",
		UserID:      bob,
	}
	nopUserAPI.accessTokens["charlie"] = userapi.Device{
		AccessToken: "charlie",
		DisplayName: "Charles",
		UserID:      charlie,
	}
	eventA := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: alice,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[A] Do you know shelties?",
		},
	})
	eventB := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: bob,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[B] I <3 shelties",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventA.EventID(),
			},
		},
	})
	eventC := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: bob,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[C] like so much",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventB.EventID(),
			},
		},
	})
	eventD := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: alice,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[D] but what are shelties???",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventB.EventID(),
			},
		},
	})
	eventE := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: bob,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[E] seriously???",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventD.EventID(),
			},
		},
	})
	eventF := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: charlie,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[F] omg how do you not know what shelties are",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventD.EventID(),
			},
		},
	})
	eventG := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: alice,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[G] looked it up, it's a sheltered person?",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventD.EventID(),
			},
		},
	})
	eventH := mustCreateEvent(t, fledglingEvent{
		RoomID: roomID,
		Sender: bob,
		Type:   "m.room.message",
		Content: map[string]interface{}{
			"body": "[H] it's a dog!!!!!",
			"m.relationship": map[string]string{
				"rel_type": "m.reference",
				"event_id": eventE.EventID(),
			},
		},
	})
	// make everyone joined to each other's rooms
	nopRsAPI := &testRoomserverAPI{
		userToJoinedRooms: map[string][]string{
			alice:   []string{roomID},
			bob:     []string{roomID},
			charlie: []string{roomID},
		},
		events: map[string]*gomatrixserverlib.HeaderedEvent{
			eventA.EventID(): eventA,
			eventB.EventID(): eventB,
			eventC.EventID(): eventC,
			eventD.EventID(): eventD,
			eventE.EventID(): eventE,
			eventF.EventID(): eventF,
			eventG.EventID(): eventG,
			eventH.EventID(): eventH,
		},
	}
	router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{
		eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH,
	})
	cancel := runServer(t, router)
	defer cancel()

	t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
		_ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{
			"event_id": "$invalid",
		}))
	})
	t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
		nopUserAPI.accessTokens["frank"] = userapi.Device{
			AccessToken: "frank",
			DisplayName: "Frank Not In Room",
			UserID:      "@frank:localhost",
		}
		_ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{
			"event_id":       eventB.EventID(),
			"limit":          1,
			"include_parent": true,
		}))
	})
	t.Run("returns the parent if include_parent is true", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":       eventB.EventID(),
			"include_parent": true,
			"limit":          2,
		}))
		assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
	})
	t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":         eventD.EventID(),
			"include_children": true,
			"recent_first":     true,
			"limit":            4,
		}))
		assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
		body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":         eventD.EventID(),
			"include_children": true,
			"recent_first":     false,
			"limit":            4,
		}))
		assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
	})
	t.Run("walks the graph depth first", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  true,
			"limit":        6,
		}))
		// Oldest first so:
		//   A
		//   |
		//   B1
		//  / \
		// C2  D3
		//    /| \
		//  4E 6F G
		//   |
		//  5H
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
		body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": true,
			"depth_first":  true,
			"limit":        6,
		}))
		// Recent first so:
		//   A
		//   |
		//   B1
		//  / \
		// C   D2
		//    /| \
		//  E5 F4 G3
		//   |
		//  H6
		assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
	})
	t.Run("walks the graph breadth first", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"limit":        6,
		}))
		// Oldest first so:
		//   A
		//   |
		//   B1
		//  / \
		// C2  D3
		//    /| \
		//  E4 F5 G6
		//   |
		//   H
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
		body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": true,
			"depth_first":  false,
			"limit":        6,
		}))
		// Recent first so:
		//   A
		//   |
		//   B1
		//  / \
		// C3  D2
		//    /| \
		//  E6 F5 G4
		//   |
		//   H
		assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
	})
	t.Run("caps via max_breadth", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"max_breadth":  2,
			"limit":        10,
		}))
		// Event G gets omitted because of max_breadth
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
	})
	t.Run("caps via max_depth", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"max_depth":    2,
			"limit":        10,
		}))
		// Event H gets omitted because of max_depth
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
	})
	t.Run("terminates when reaching the limit", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"limit":        4,
		}))
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
	})
	t.Run("returns all events with a high enough limit", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"limit":        400,
		}))
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
	})
	t.Run("can navigate up the graph with direction: up", func(t *testing.T) {
		//   A4
		//   |
		//   B3
		//  / \
		// C   D2
		//    /| \
		//   E F1 G
		//   |
		//   H
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventF.EventID(),
			"recent_first": false,
			"depth_first":  true,
			"direction":    "up",
		}))
		assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
	})
	t.Run("includes children and children_hash in unsigned", func(t *testing.T) {
		body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
			"event_id":     eventB.EventID(),
			"recent_first": false,
			"depth_first":  false,
			"limit":        3,
		}))
		// event B has C,D as children
		// event C has no children
		// event D has 3 children (not included in response)
		assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()})
		assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()})
		assertUnsignedChildren(t, body.Events[1], "", 0, nil)
		assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()})
	})
}

// TODO: TestMSC2836TerminatesLoops (short and long)
// TODO: TestMSC2836UnknownEventsSkipped
// TODO: TestMSC2836SkipEventIfNotInRoom

func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest {
	t.Helper()
	b, err := json.Marshal(jsonBody)
	if err != nil {
		t.Fatalf("Failed to marshal request: %s", err)
	}
	r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b))
	if err != nil {
		t.Fatalf("Failed to NewEventRelationshipRequest: %s", err)
	}
	return r
}

func runServer(t *testing.T, router *mux.Router) func() {
	t.Helper()
	externalServ := &http.Server{
		Addr:         string(":8009"),
		WriteTimeout: 60 * time.Second,
		Handler:      router,
	}
	go func() {
		externalServ.ListenAndServe()
	}()
	// wait to listen on the port
	time.Sleep(500 * time.Millisecond)
	return func() {
		externalServ.Shutdown(context.TODO())
	}
}

func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
	t.Helper()
	var r msc2836.EventRelationshipRequest
	r.Defaults()
	data, err := json.Marshal(req)
	if err != nil {
		t.Fatalf("failed to marshal request: %s", err)
	}
	httpReq, err := http.NewRequest(
		"POST", "http://localhost:8009/_matrix/client/unstable/event_relationships",
		bytes.NewBuffer(data),
	)
	httpReq.Header.Set("Authorization", "Bearer "+accessToken)
	if err != nil {
		t.Fatalf("failed to prepare request: %s", err)
	}
	res, err := client.Do(httpReq)
	if err != nil {
		t.Fatalf("failed to do request: %s", err)
	}
	if res.StatusCode != expectCode {
		body, _ := ioutil.ReadAll(res.Body)
		t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
	}
	if res.StatusCode == 200 {
		var result msc2836.EventRelationshipResponse
		body, err := ioutil.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("response 200 OK but failed to read response body: %s", err)
		}
		if err := json.Unmarshal(body, &result); err != nil {
			t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
		}
		return &result
	}
	return nil
}

func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wantEventIDs []string) {
	t.Helper()
	gotEventIDs := make([]string, len(result.Events))
	for i, ev := range result.Events {
		gotEventIDs[i] = ev.EventID
	}
	if len(gotEventIDs) != len(wantEventIDs) {
		t.Fatalf("length mismatch: got %v want %v", gotEventIDs, wantEventIDs)
	}
	for i := range gotEventIDs {
		if gotEventIDs[i] != wantEventIDs[i] {
			t.Errorf("wrong item in position %d - got %s want %s", i, gotEventIDs[i], wantEventIDs[i])
		}
	}
}

func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) {
	t.Helper()
	unsigned := struct {
		Children map[string]int `json:"children"`
		Hash     string         `json:"children_hash"`
	}{}
	if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil {
		if wantCount == 0 {
			return // no children so possible there is no unsigned field at all
		}
		t.Fatalf("Failed to unmarshal unsigned field: %s", err)
	}
	// zero checks
	if wantCount == 0 {
		if len(unsigned.Children) != 0 || unsigned.Hash != "" {
			t.Fatalf("want 0 children but got unsigned fields %+v", unsigned)
		}
		return
	}
	gotCount := unsigned.Children[relType]
	if gotCount != wantCount {
		t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType)
	}
	// work out the hash
	sort.Strings(childrenEventIDs)
	var b strings.Builder
	for _, s := range childrenEventIDs {
		b.WriteString(s)
	}
	t.Logf("hashing %s", b.String())
	hashValBytes := sha256.Sum256([]byte(b.String()))
	wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:])
	if wantHash != unsigned.Hash {
		t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash)
	}
}

type testUserAPI struct {
	userapi.UserInternalAPITrace
	accessTokens map[string]userapi.Device
}

func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
	dev, ok := u.accessTokens[req.AccessToken]
	if !ok {
		res.Err = "unknown token"
		return nil
	}
	res.Device = &dev
	return nil
}

type testRoomserverAPI struct {
	// use a trace API as it implements method stubs so we don't need to have them here.
	// We'll override the functions we care about.
	roomserver.RoomserverInternalAPITrace
	userToJoinedRooms map[string][]string
	events            map[string]*gomatrixserverlib.HeaderedEvent
}

func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
	for _, eventID := range req.EventIDs {
		ev := r.events[eventID]
		if ev != nil {
			res.Events = append(res.Events, ev)
		}
	}
	return nil
}

func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
	rooms := r.userToJoinedRooms[req.UserID]
	for _, roomID := range rooms {
		if roomID == req.RoomID {
			res.IsInRoom = true
			res.HasBeenInRoom = true
			res.Membership = "join"
			break
		}
	}
	return nil
}

func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
	t.Helper()
	cfg := &config.Dendrite{}
	cfg.Defaults(true)
	cfg.Global.ServerName = "localhost"
	cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db"
	cfg.MSCs.MSCs = []string{"msc2836"}
	base := &base.BaseDendrite{
		Cfg:                    cfg,
		PublicClientAPIMux:     mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
		PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
	}

	err := msc2836.Enable(base, rsAPI, nil, userAPI, nil)
	if err != nil {
		t.Fatalf("failed to enable MSC2836: %s", err)
	}
	for _, ev := range events {
		hooks.Run(hooks.KindNewEventPersisted, ev)
	}
	return base.PublicClientAPIMux
}

type fledglingEvent struct {
	Type     string
	StateKey *string
	Content  interface{}
	Sender   string
	RoomID   string
}

func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
	t.Helper()
	roomVer := gomatrixserverlib.RoomVersionV6
	seed := make([]byte, ed25519.SeedSize) // zero seed
	key := ed25519.NewKeyFromSeed(seed)
	eb := gomatrixserverlib.EventBuilder{
		Sender:   ev.Sender,
		Depth:    999,
		Type:     ev.Type,
		StateKey: ev.StateKey,
		RoomID:   ev.RoomID,
	}
	err := eb.SetContent(ev.Content)
	if err != nil {
		t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
	}
	// make sure the origin_server_ts changes so we can test recency
	time.Sleep(1 * time.Millisecond)
	signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
	if err != nil {
		t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
	}
	h := signedEvent.Headered(roomVer)
	return h
}