diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index 0773a82..d45e95c 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -224,6 +224,7 @@ def __call__( ) except Exception as e: + _drain_request_body(environ) return self._handle_error(e, ctx, start_response) def _handle_unary( @@ -468,12 +469,13 @@ def _handle_stream( # been called in time. So we return the response stream as a separate generator # function. This means some duplication of error handling. return _response_stream( - first_response, response_stream, writer, send_trailers, ctx + first_response, environ, response_stream, writer, send_trailers, ctx ) except Exception as e: # Exception before any response message was returned. An error after the first # response message will be handled by _response_stream, so here we have a # full error-only response. + _drain_request_body(environ) _send_stream_response_headers( start_response, protocol, codec, resp_compression.name(), ctx ) @@ -554,21 +556,13 @@ def _request_stream( read_max_bytes: int | None = None, ) -> Iterator[_REQ]: reader = EnvelopeReader(request_class, codec, compression, read_max_bytes) - try: - for chunk in _read_body(environ): - yield from reader.feed(chunk) - except ConnectError: - if environ.get("SERVER_PROTOCOL", "").startswith("HTTP/1"): - # In HTTP/1, the request body should be drained before returning. Generally it's - # best for the application server to handle this, but gunicorn is a famous - # server that doesn't do so, so we go ahead and do it ourselves. - for _ in _read_body(environ): - pass - raise + for chunk in _read_body(environ): + yield from reader.feed(chunk) def _response_stream( first_response: _RES, + environ: WSGIEnvironment, response_stream: Iterator[_RES], writer: EnvelopeWriter, send_trailers: Callable[[list[tuple[str, str]]], None] | None, @@ -583,6 +577,7 @@ def _response_stream( yield body except Exception as e: error = e + _drain_request_body(environ) yield _end_response( writer.end( @@ -638,3 +633,12 @@ def _apply_interceptors( continue func = functools.partial(interceptor.intercept_bidi_stream_sync, func) return replace(endpoint, function=func) + + +def _drain_request_body(environ: WSGIEnvironment) -> None: + if environ.get("SERVER_PROTOCOL", "").startswith("HTTP/1"): + # In HTTP/1, the request body should be drained before returning. Generally it's + # best for the application server to handle this, but gunicorn is a famous + # server that doesn't do so, so we go ahead and do it ourselves. + for _ in _read_body(environ): + pass