// Request batcher.
//
// From GitHub gist CAFxX/a6fca31790e0dcc773390c1faf2e9f86 (last active July 27, 2023 01:42),
// which can be saved locally or opened in GitHub Desktop.
//
// Note from GitHub: this file may contain bidirectional Unicode text that could be
// interpreted or compiled differently than it appears. To review it, open the file in an
// editor that reveals hidden Unicode characters (see GitHub's documentation on
// bidirectional Unicode characters).
package batchgetter

import (
	"context"
	"fmt"
	"sync"
	"time"
)
// Getter fetches the values for a list of ids.
//
// Implementations are expected to return exactly one value per id, in the
// same order as the ids: BatchGetter maps the returned values back to the
// requested ids by position.
type Getter[I, T any] interface {
	Get(ctx context.Context, ids []I) ([]T, error)
}
type BatchGetter[I, T any] struct { | |
parent Getter[I, T] | |
batchWait time.Duration | |
mu sync.Mutex | |
ctx []context.Context | |
batch []I | |
batchTimer *time.Timer | |
resCh chan struct{} | |
res *result[I, T] | |
} | |
var _ Getter[I, T] = (*BatchGetter[I, T])(nil) | |
// result is the outcome of one batch, shared by every Get call in that
// batch. It is written by the batch's get call and read by waiters only
// after the batch's result channel has been closed.
//
// I must be comparable because it is used as a map key.
type result[I comparable, T any] struct {
	m   map[I]T // value for each id in the batch
	err error   // non-nil if the batch failed; m must then be ignored
}
func (g *BatchGetter[I, T]) Get(ctx context.Context, id []I) ([]T, error) { | |
if g.batchWait <= 0 { | |
return g.parent.Get(ctx, id) | |
} | |
mu.Lock() | |
if batchTimer == nil { | |
g.resCh = make(chan struct{}) | |
g.res = new(result[I, T]) | |
g.batchTimer = time.AfterFunc(g.batchWait, g.get) | |
} | |
g.batch = append(g.batch, id...) | |
g.ctx = append(g.ctx, ctx) | |
res := g.res | |
resCh := g.resCh | |
mu.Unlock() | |
select { | |
case <-ctx.Done(): | |
return nil, ctx.Err() | |
case <-resCh: | |
} | |
if res.err != nil { | |
return nil, res.err | |
} | |
r := make([]T, 0, len(id)) | |
for _, e := range id { | |
r = append(r, res.m[e]) | |
} | |
return r, nil | |
} | |
func (g *BatchGetter[I, T]) get() { | |
g.mu.Lock() | |
batch, ctx, res, resCh := g.batch[:len(g.batch):len(g.batch)], g.ctx[:len(g.ctx):len(g.ctx)], g.res, g.resCh | |
g.batchTimer, g.batch, g.ctx, g.res, g.resCh = nil, g.batch[len(g.batch):], g.ctx[len(g.ctx):], nil, nil | |
g.mu.Unlock() | |
defer close(resCh) | |
defer func() { | |
if r := recover(); r != nil { | |
res.m = nil | |
res.err = fmt.Errorf("panic: %v", r) | |
} | |
}() | |
res.m = make(map[I]T, len(batch)) | |
for _, e := range batch { | |
// TODO: filter out entries for which the context has already expired | |
var zero T | |
res.m[e] = zero | |
} | |
batch = batch[:0] | |
for k := range g.res.m { | |
batch = append(batch, k) | |
} | |
actx, cancel := anyCtx(ctx) | |
if cancel != nil { | |
defer cancel() | |
} | |
res, err := g.parent.Get(actx, batch) | |
for i, e := range res { | |
res.m[i] = e | |
} | |
res.err = err | |
} | |
// anyCtx combines the contexts of every caller in a batch into one context
// for the shared parent request.
//
// The combined context is done only after all of ctxs are done, so the
// shared request keeps running while at least one caller may still want its
// result. If ctxs is empty, or one of them can never be canceled (its Done
// returns nil), context.Background() is returned. A single context is
// returned as is. Otherwise a new context derived from context.Background()
// is returned, so values carried by ctxs are not propagated.
//
// The returned cancel func is nil when no new context was created. Otherwise
// the caller should call it when done with the context; this also stops the
// goroutine watching ctxs.
//
// A nil context is a programmer error and panics, unless a never-canceled
// context appears earlier in ctxs.
func anyCtx(ctxs []context.Context) (context.Context, func()) {
	if len(ctxs) == 0 {
		return context.Background(), nil
	}
	for _, ctx := range ctxs {
		if ctx == nil {
			panic("nil context")
		}
		// One caller that can never cancel keeps the request alive forever,
		// so the combined context can never be canceled either.
		if ctx.Done() == nil {
			return context.Background(), nil
		}
	}
	if len(ctxs) == 1 {
		return ctxs[0], nil
	}
	actx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		// Wait for every input in turn; once the last one is done, nobody is
		// interested in the result any more.
		for _, ctx := range ctxs {
			select {
			case <-actx.Done():
				// The caller released actx: stop watching the remaining inputs.
				return
			case <-ctx.Done():
			}
		}
	}()
	return actx, cancel
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.